ESPResSo
Extensible Simulation Package for Research on Soft Matter Systems
Loading...
Searching...
No Matches
MpiCallbacks.hpp
Go to the documentation of this file.
1/*
2 * Copyright (C) 2010-2022 The ESPResSo project
3 * Copyright (C) 2002,2003,2004,2005,2006,2007,2008,2009,2010
4 * Max-Planck-Institute for Polymer Research, Theory Group
5 *
6 * This file is part of ESPResSo.
7 *
8 * ESPResSo is free software: you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation, either version 3 of the License, or
11 * (at your option) any later version.
12 *
13 * ESPResSo is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with this program. If not, see <http://www.gnu.org/licenses/>.
20 */
21
22#pragma once
23
24/**
25 * @file
26 *
27 * @ref Communication::MpiCallbacks manages MPI communication using a
28 * visitor pattern. The program runs on the head node and is responsible
29 * for calling callback functions on the worker nodes when necessary,
30 * e.g. to broadcast global variables or run an algorithm in parallel.
31 *
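 * Static callbacks are registered as function pointers via the
 * @ref REGISTER_CALLBACK macro; dynamic callbacks (functors and lambdas)
 * are managed through @ref Communication::MpiCallbacks::CallbackHandle.
 * The visitor pattern allows using arbitrary function signatures.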
35 */
36
38
#include <boost/mpi/collectives/broadcast.hpp>
#include <boost/mpi/communicator.hpp>
#include <boost/mpi/environment.hpp>
#include <boost/mpi/packed_iarchive.hpp>
#include <boost/mpi/packed_oarchive.hpp>

#include <cassert>
#include <memory>
#include <stdexcept>
#include <tuple>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>
50
51namespace Communication {
52
53namespace detail {
/**
 * @brief Check if a type can be used as a callback argument.
 *
 * A parameter type is accepted unless it is a pointer or an lvalue
 * reference to non-const (e.g. @c int&). Such output parameters cannot
 * work across ranks, because changes made by the callee are never sent
 * back. Plain values, rvalue references and references to const are
 * accepted.
 *
 * @tparam T Parameter type as spelled in the callback signature.
 */
template <class T>
using is_allowed_argument = std::bool_constant<
    !std::is_pointer_v<T> &&
    !(std::is_lvalue_reference_v<T> &&
      !std::is_const_v<std::remove_reference_t<T>>)>;
67
68/**
69 * @brief Invoke a callable with arguments from an mpi buffer.
70 *
71 * @tparam F A Callable that can be called with Args as parameters.
72 * @tparam Args Pack of arguments for @p F
73 *
74 * @param f Functor to be called
75 * @param ia Buffer to extract the parameters from
76 *
77 * @return Return value of calling @p f.
78 */
79template <class F, class... Args>
80auto invoke(F f, boost::mpi::packed_iarchive &ia) {
81 static_assert(std::conjunction_v<is_allowed_argument<Args>...>,
82 "Pointers and non-const references are not allowed as "
83 "arguments for callbacks.");
84
85 /* This is the local receive buffer for the parameters. We have to strip
86 away const so we can actually deserialize into it. */
87 std::tuple<std::remove_const_t<std::remove_reference_t<Args>>...> params;
88 std::apply([&ia](auto &&...e) { ((ia >> e), ...); }, params);
89
90 /* We add const here, so that parameters can only be by value
91 or const reference. Output parameters on callbacks are not
92 sensible because the changes are not propagated back, so
93 we make sure this does not compile. */
94 return std::apply(f, std::as_const(params));
95}
96
97/**
98 * @brief Type-erased interface for callbacks.
99 *
100 * This encapsulates the signature of the callback
101 * and the parameter transfer, so that it can be
102 * called without any type information on the parameters.
103 */
104struct callback_concept_t {
105 /**
106 * @brief Execute the callback.
107 *
108 * Unpack parameters for this callback, and then call it.
109 */
110 virtual void operator()(boost::mpi::communicator const &,
111 boost::mpi::packed_iarchive &) const = 0;
112 virtual ~callback_concept_t() = default;
113};
114
115/**
116 * @brief Callback without a return value.
117 *
118 * This is an implementation of a callback for a specific callable
119 * @p F and a set of arguments to call it with.
120 */
121template <class F, class... Args>
122struct callback_void_t final : public callback_concept_t {
123 F m_f;
124
125 callback_void_t(callback_void_t const &) = delete;
126 callback_void_t(callback_void_t &&) = delete;
127
128 template <class FRef>
129 explicit callback_void_t(FRef &&f) : m_f(std::forward<FRef>(f)) {}
130 void operator()(boost::mpi::communicator const &,
131 boost::mpi::packed_iarchive &ia) const override {
132 detail::invoke<F, Args...>(m_f, ia);
133 }
134};
135
/**
 * @brief Compile-time summary of a callable's signature.
 *
 * @tparam F    The functor type.
 * @tparam R    Return type of its call operator.
 * @tparam Args Parameter types of its call operator.
 */
template <class F, class R, class... Args> struct FunctorTypes {
  using functor_type = F;
  using return_type = R;
  using argument_types = std::tuple<Args...>;
};

/**
 * @brief Deduce @ref FunctorTypes from a pointer to a const call operator.
 *
 * Only used in unevaluated context (see functor_types); the returned
 * object is an empty tag.
 */
template <class C, class R, class... Args>
auto functor_types_impl(R (C::*)(Args...) const)
    -> FunctorTypes<C, R, Args...> {
  return {};
}

/**
 * @brief @ref FunctorTypes of @p Functor, references stripped.
 *
 * Requires a single, non-template, const-qualified operator(); overloaded
 * or generic call operators cannot be named by address.
 */
template <class Functor>
using functor_types = decltype(functor_types_impl(
    &std::remove_reference_t<Functor>::operator()));
150
151template <class CRef, class C, class R, class... Args>
152auto make_model_impl(CRef &&c, FunctorTypes<C, R, Args...>) {
153 return std::make_unique<callback_void_t<C, Args...>>(std::forward<CRef>(c));
154}
155
156/**
157 * @brief Make a @ref callback_model_t for a functor or lambda.
158 *
159 * The signature is deduced from F::operator() const, which has
160 * to exist and can not be overloaded.
161 */
162template <typename F> auto make_model(F &&f) {
163 return make_model_impl(std::forward<F>(f), functor_types<F>{});
164}
165
166/**
167 * @brief Make a @ref callback_model_t for a function pointer.
168 */
169template <class... Args> auto make_model(void (*f_ptr)(Args...)) {
170 return std::make_unique<callback_void_t<void (*)(Args...), Args...>>(f_ptr);
171}
172} // namespace detail
173
174/**
175 * @brief The interface of the MPI callback mechanism.
176 */
178public:
179 /**
180 * @brief RAII handle for a callback.
181 *
182 * This is what the client gets for registering a
183 * dynamic (= not function pointer) callback.
184 * It manages the lifetime of the callback handle
185 * needed to call it. The handle has a type derived
186 * from the signature of the callback, which makes
187 * it possible to do static type checking on the
188 * arguments.
189 */
190 template <class... Args> class CallbackHandle {
191 public:
192 template <typename F>
193 requires(std::is_same_v<typename detail::functor_types<F>::argument_types,
194 std::tuple<Args...>>)
195 CallbackHandle(std::shared_ptr<MpiCallbacks> cb, F &&f)
196 : m_id(cb->add(std::forward<F>(f))), m_cb(std::move(cb)) {}
197
199 CallbackHandle(CallbackHandle &&rhs) noexcept = default;
201 CallbackHandle &operator=(CallbackHandle &&rhs) noexcept = default;
202
203 private:
204 int m_id;
205 std::shared_ptr<MpiCallbacks> m_cb;
206
207 public:
208 /**
209 * @brief Call the callback managed by this handle.
210 *
211 * The arguments are passed to the remote callees, it
212 * must be possible to call the function with the provided
213 * arguments, otherwise this will not compile.
214 */
215 template <class... ArgRef>
216 auto operator()(ArgRef &&...args) const
217 /* Enable if a hypothetical function with signature void(Args..)
218 * could be called with the provided arguments. */
219 requires(std::is_void_v<decltype(std::declval<void (*)(Args...)>()(
220 std::forward<ArgRef>(args)...))>)
221 {
222 if (m_cb)
223 m_cb->call(m_id, std::forward<ArgRef>(args)...);
224 }
225
227 if (m_cb)
228 m_cb->remove(m_id);
229 }
230
231 int id() const { return m_id; }
232 };
233
234 /* Avoid accidental copy, leads to mpi deadlock or split brain */
235 MpiCallbacks(MpiCallbacks const &) = delete;
237
238private:
239 static auto &static_callbacks() {
240 static std::vector<
241 std::pair<void (*)(), std::unique_ptr<detail::callback_concept_t>>>
242 callbacks;
243
244 return callbacks;
245 }
246
247public:
248 MpiCallbacks(boost::mpi::communicator comm,
249 std::shared_ptr<boost::mpi::environment> mpi_env)
250 : m_comm(std::move(comm)), m_mpi_env(std::move(mpi_env)) {
251 /* Add a dummy at id 0 for loop abort. */
252 m_callback_map.add(nullptr);
253
254 for (auto &[fp, handle] : static_callbacks()) {
255 m_func_ptr_to_id[fp] = m_callback_map.add(handle.get());
256 }
257 }
258
260 /* Release the clients on exit */
261 if (m_comm.rank() == 0) {
262 try {
263 abort_loop();
264 } catch (...) { // NOLINT(bugprone-empty-catch)
265 }
266 }
267 /* MPI_Finalize is unsafe if there are pending non-blocking operations */
268 m_comm.barrier();
269 m_mpi_env.reset();
270 }
271
272private:
273 /**
274 * @brief Add a new callback.
275 *
276 * Add a new callback to the system. This is a collective
277 * function that must be run on all nodes.
278 *
279 * @tparam F An object with a const call operator.
280 *
281 * @param f The callback function to add.
282 * @return A handle with which the callback can be called.
283 */
284 template <typename F> auto add(F &&f) {
285 m_callbacks.emplace_back(detail::make_model(std::forward<F>(f)));
286 return m_callback_map.add(m_callbacks.back().get());
287 }
288
289public:
290 /**
291 * @brief Add a new callback.
292 *
293 * Add a new callback to the system. This is a collective
294 * function that must be run on all nodes.
295 *
296 * @param fp Pointer to the static callback function to add.
297 */
298 template <class... Args> void add(void (*fp)(Args...)) {
299 m_callbacks.emplace_back(detail::make_model(fp));
300 const int id = m_callback_map.add(m_callbacks.back().get());
301 m_func_ptr_to_id[reinterpret_cast<void (*)()>(fp)] = id;
302 }
303
304 /**
305 * @brief Add a new callback.
306 *
307 * Add a new callback to the system. This is a collective
308 * function that must be run on all nodes.
309 *
310 * @param fp Pointer to the static callback function to add.
311 */
312 template <class... Args> static void add_static(void (*fp)(Args...)) {
313 static_callbacks().emplace_back(reinterpret_cast<void (*)()>(fp),
314 detail::make_model(fp));
315 }
316
317private:
318 /**
319 * @brief Remove callback.
320 *
321 * Remove the callback id from the callback list.
322 * This is a collective call that must be run on all nodes.
323 *
324 * @param id Identifier of the callback to remove.
325 */
326 void remove(int id) {
327 std::erase_if(m_callbacks, [ptr = m_callback_map[id]](auto const &e) {
328 return e.get() == ptr;
329 });
330 m_callback_map.remove(id);
331 }
332
333private:
334 /**
335 * @brief call a callback.
336 *
337 * Call the callback id.
338 * The method can only be called on the head node
339 * and has the prerequisite that the other nodes are
340 * in the MPI loop.
341 *
342 * @param id The callback to call.
343 * @param args Arguments for the callback.
344 */
345 template <class... Args> void call(int id, Args &&...args) const {
346 if (m_comm.rank() != 0) {
347 throw std::logic_error("Callbacks can only be invoked on rank 0.");
348 }
349
350 assert(m_callback_map.find(id) != m_callback_map.end() &&
351 "m_callback_map and m_func_ptr_to_id disagree");
352
353 /* Send request to worker nodes */
354 boost::mpi::packed_oarchive oa(m_comm);
355 oa << id;
356
357 /* Pack the arguments into a packed mpi buffer. */
358 std::apply([&oa](auto &&...e) { ((oa << e), ...); },
359 std::forward_as_tuple(std::forward<Args>(args)...));
360
361 boost::mpi::broadcast(m_comm, oa, 0);
362 }
363
364public:
365 /**
366 * @brief Call a callback on worker nodes.
367 *
368 * The callback is **not** called on the head node.
369 *
370 * This method can only be called on the head node.
371 *
372 * @param fp Pointer to the function to call.
373 * @param args Arguments for the callback.
374 */
375 template <class... Args, class... ArgRef>
376 auto call(void (*fp)(Args...), ArgRef &&...args) const
377 /* enable only if fp can be called with the provided arguments */
378 requires(std::is_void_v<decltype(fp(args...))>)
379 {
380 const int id = m_func_ptr_to_id.at(reinterpret_cast<void (*)()>(fp));
381
382 call(id, std::forward<ArgRef>(args)...);
383 }
384
385 /**
386 * @brief Call a callback on all nodes.
387 *
388 * This calls a callback on all nodes, including the head node.
389 *
390 * This method can only be called on the head node.
391 *
392 * @param fp Pointer to the function to call.
393 * @param args Arguments for the callback.
394 */
395 template <class... Args, class... ArgRef>
396 auto call_all(void (*fp)(Args...), ArgRef &&...args) const
397 /* enable only if fp can be called with the provided arguments */
398 requires(std::is_void_v<decltype(fp(args...))>)
399 {
400 call(fp, args...);
401 fp(args...);
402 }
403
404 /**
405 * @brief Start the MPI loop.
406 *
407 * This is the callback loop for the worker nodes. They block
408 * on the MPI call and wait for a new callback request
409 * coming from the head node.
410 * This should be run on the worker nodes and must be running
411 * so that the head node can issue call().
412 */
413 void loop() const {
414 for (;;) {
415 /* Communicate callback id and parameters */
416 boost::mpi::packed_iarchive ia(m_comm);
417 boost::mpi::broadcast(m_comm, ia, 0);
418
419 int request;
420 ia >> request;
421
422 if (request == LOOP_ABORT) {
423 break;
424 }
425 /* Call the callback */
426 m_callback_map[request]->operator()(m_comm, ia);
427 }
428 }
429
430 /**
431 * @brief Abort the MPI loop.
432 *
433 * Make the worker nodes exit the MPI loop.
434 */
435 void abort_loop() { call(LOOP_ABORT); }
436
437 /**
438 * @brief The boost mpi communicator used by this instance
439 */
440 boost::mpi::communicator const &comm() const { return m_comm; }
441
442private:
443 /**
444 * @brief Id for the @ref abort_loop. Has to be 0.
445 */
446 static constexpr int LOOP_ABORT = 0;
447
448 /**
449 * The MPI communicator used for the callbacks.
450 */
451 boost::mpi::communicator m_comm;
452
453 /**
454 * The MPI environment used for the callbacks.
455 */
456 std::shared_ptr<boost::mpi::environment> m_mpi_env;
457
458 /**
459 * Internal storage for the callback functions.
460 */
461 std::vector<std::unique_ptr<detail::callback_concept_t>> m_callbacks;
462
463 /**
464 * Map of ids to callbacks.
465 */
467
468 /**
469 * Mapping of function pointers to ids, so static callbacks can be
470 * called by their pointer.
471 */
472 std::unordered_map<void (*)(), int> m_func_ptr_to_id;
473};
474
475template <class... Args>
477
478/**
479 * @brief Helper class to add callbacks before main.
480 *
481 * Should not be used directly, but via @ref REGISTER_CALLBACK.
482 */
484
485public:
487
488 template <class... Args> explicit RegisterCallback(void (*cb)(Args...)) {
490 }
491};
492} /* namespace Communication */
493
/**
 * @brief Register a static callback without return value.
 *
 * Defines a namespace-scope @ref Communication::RegisterCallback object
 * constructed from @p cb, so that the function is added as an mpi
 * callback before main. Use the macro at global scope, because it
 * reopens namespace @c Communication.
 *
 * @param cb A function with signature @c void(Args...). It must be a
 *           plain identifier, since it is token-pasted into the name
 *           of the registration object.
 */
#define REGISTER_CALLBACK(cb)                                                  \
  namespace Communication {                                                    \
  static ::Communication::RegisterCallback register_##cb{&(cb)};               \
  }
Keep an enumerated list of T objects, managed by the class.
CallbackHandle(CallbackHandle &&rhs) noexcept=default
auto operator()(ArgRef &&...args) const
Call the callback managed by this handle.
CallbackHandle(CallbackHandle const &)=delete
CallbackHandle(std::shared_ptr< MpiCallbacks > cb, F &&f)
CallbackHandle & operator=(CallbackHandle &&rhs) noexcept=default
CallbackHandle & operator=(CallbackHandle const &)=delete
The interface of the MPI callback mechanism.
auto call_all(void(*fp)(Args...), ArgRef &&...args) const
Call a callback on all nodes.
MpiCallbacks(boost::mpi::communicator comm, std::shared_ptr< boost::mpi::environment > mpi_env)
void add(void(*fp)(Args...))
Add a new callback.
boost::mpi::communicator const & comm() const
The boost mpi communicator used by this instance.
void abort_loop()
Abort the MPI loop.
static void add_static(void(*fp)(Args...))
Add a new callback.
MpiCallbacks & operator=(MpiCallbacks const &)=delete
auto call(void(*fp)(Args...), ArgRef &&...args) const
Call a callback on worker nodes.
MpiCallbacks(MpiCallbacks const &)=delete
void loop() const
Start the MPI loop.
Helper class to add callbacks before main.
RegisterCallback(void(*cb)(Args...))
Container for objects that are identified by a numeric id.
STL namespace.
static SteepestDescentParameters params
Currently active steepest descent instance.