22#ifndef COMMUNICATION_MPI_CALLBACKS
23#define COMMUNICATION_MPI_CALLBACKS
42#include <boost/mpi/collectives/broadcast.hpp>
43#include <boost/mpi/communicator.hpp>
44#include <boost/mpi/environment.hpp>
45#include <boost/mpi/packed_iarchive.hpp>
/**
 * Compile-time predicate: true unless T is a pointer type or a non-const
 * lvalue reference. For example `int *` and `int &` are rejected, while
 * `int`, `int const &` and `int &&` are accepted. This matches the rule
 * enforced by the static_assert in detail::invoke ("Pointers and non-const
 * references are not allowed as arguments for callbacks.").
 * NOTE(review): the `template <class T>` header of this alias is not part of
 * this listing.
 */
67using is_allowed_argument =
68 std::integral_constant<bool,
69 not(std::is_pointer_v<T> ||
70 (!std::is_const_v<std::remove_reference_t<T>> &&
71 std::is_lvalue_reference_v<T>))>;
73template <
class... Args>
74using are_allowed_arguments =
/**
 * Call `f` with one value per parameter type in Args...
 *
 * Parameters that are pointers or non-const references are rejected at
 * compile time. The arguments live in a local tuple `params` that holds
 * each type with reference and const stripped. `f` is then applied to a
 * const view of that tuple, so it receives const lvalues and cannot modify
 * them.
 *
 * @param f   callable to invoke
 * @param ia  input archive the arguments are read from
 *            (NOTE(review): the lines that read `params` from `ia` are
 *            not part of this listing; presumably they deserialize the
 *            tuple — confirm)
 * @return whatever `f` returns
 */
88template <
class F,
class... Args>
89auto invoke(F f, boost::mpi::packed_iarchive &ia) {
90 static_assert(are_allowed_arguments<Args...>::value,
91 "Pointers and non-const references are not allowed as "
92 "arguments for callbacks.");
// Storage for the call arguments: decayed (non-reference, non-const) copies.
96 std::tuple<std::remove_const_t<std::remove_reference_t<Args>>...>
params;
// std::as_const makes std::apply pass each element as a const lvalue.
103 return std::apply(f, std::as_const(
params));
/**
 * Type-erased interface for a registered callback.
 *
 * Implementations are invoked with the communicator and a packed input
 * archive. Instances are owned through
 * std::unique_ptr<detail::callback_concept_t> (the MpiCallbacks::m_callbacks
 * container), which is why the destructor is virtual.
 */
113struct callback_concept_t {
119 virtual void operator()(boost::mpi::communicator
const &,
120 boost::mpi::packed_iarchive &)
const = 0;
121 virtual ~callback_concept_t() =
default;
/**
 * Concrete callback model wrapping a callable of type F whose parameter
 * types are Args...
 *
 * The callable is stored in `m_f`, initialized by perfect-forwarding the
 * constructor argument. When invoked, it calls
 * detail::invoke<F, Args...>(m_f, ia) and discards any result. The
 * communicator argument is ignored. The class is neither copyable nor
 * movable.
 * NOTE(review): the declaration of `m_f` is not part of this listing.
 */
130template <
class F,
class... Args>
131struct callback_void_t final :
public callback_concept_t {
134 callback_void_t(callback_void_t
const &) =
delete;
135 callback_void_t(callback_void_t &&) =
delete;
137 template <
class FRef>
138 explicit callback_void_t(FRef &&f) : m_f(std::forward<FRef>(f)) {}
139 void operator()(boost::mpi::communicator
const &,
140 boost::mpi::packed_iarchive &ia)
const override {
141 detail::invoke<F, Args...>(m_f, ia);
/**
 * Empty compile-time bundle describing a callable: its class type
 * (functor_type), return type (return_type) and parameter types packed
 * into a std::tuple (argument_types). Value-initialized instances serve
 * as tags for template argument deduction.
 */
145template <
class F,
class R,
class... Args>
struct FunctorTypes {
146 using functor_type = F;
147 using return_type = R;
148 using argument_types = std::tuple<Args...>;
/**
 * Deduce class, return and parameter types from a pointer to a
 * const-qualified member call operator `R (C::*)(Args...) const`.
 *
 * Only matches const call operators: a non-const operator() (e.g. a
 * `mutable` lambda) does not fit this signature. Taking
 * `&F::operator()` is also ambiguous for overloaded or templated call
 * operators (e.g. generic lambdas).
 */
151template <
class C,
class R,
class... Args>
152auto functor_types_impl(R (C::*)(Args...) const) {
153 return FunctorTypes<C, R, Args...>{};
158 decltype(functor_types_impl(&std::remove_reference_t<F>::operator()));
/**
 * Build a heap-allocated callback_void_t<C, Args...> from `c`, which is
 * perfect-forwarded into the model. The FunctorTypes argument is used only
 * to deduce C and Args...; the return type R is not used.
 */
160template <
class CRef,
class C,
class R,
class... Args>
161auto make_model_impl(CRef &&c, FunctorTypes<C, R, Args...>) {
162 return std::make_unique<callback_void_t<C, Args...>>(std::forward<CRef>(c));
/**
 * Wrap a function object into a callback model. Its parameter types are
 * deduced from its call operator via functor_types<F>, and the object
 * itself is forwarded into the model.
 */
171template <
typename F>
auto make_model(F &&f) {
172 return make_model_impl(std::forward<F>(f), functor_types<F>{});
/**
 * Overload for free functions. Only void-returning function pointers match.
 * The pointer itself is stored as the callable, and Args... are taken
 * directly from its signature.
 */
178template <
class... Args>
auto make_model(
void (*f_ptr)(Args...)) {
179 return std::make_unique<callback_void_t<void (*)(Args...), Args...>>(f_ptr);
201 template <
typename F,
class = std::enable_if_t<std::is_same_v<
202 typename detail::functor_types<F>::argument_types,
203 std::tuple<Args...>>>>
205 : m_id(cb->add(std::forward<F>(f))), m_cb(std::move(cb)) {}
214 std::shared_ptr<MpiCallbacks> m_cb;
224 template <
class... ArgRef>
229 std::is_void_v<
decltype(std::declval<void (*)(Args...)>()(
230 std::forward<ArgRef>(args)...))>> {
232 m_cb->call(m_id, std::forward<ArgRef>(args)...);
240 int id()
const {
return m_id; }
248 static auto &static_callbacks() {
250 std::pair<void (*)(), std::unique_ptr<detail::callback_concept_t>>>
258 std::shared_ptr<boost::mpi::environment> mpi_env)
259 : m_comm(std::move(
comm)), m_mpi_env(std::move(mpi_env)) {
261 m_callback_map.add(
nullptr);
263 for (
auto &kv : static_callbacks()) {
264 m_func_ptr_to_id[kv.first] = m_callback_map.add(kv.second.get());
270 if (m_comm.rank() == 0) {
/**
 * Register a function object as a callback.
 *
 * The model is owned by m_callbacks, and its raw pointer is entered into
 * m_callback_map.
 *
 * @return the id assigned by m_callback_map.add()
 */
290 template <
typename F>
auto add(F &&f) {
291 m_callbacks.emplace_back(detail::make_model(std::forward<F>(f)));
292 return m_callback_map.add(m_callbacks.back().get());
/**
 * Register a free function as a callback.
 *
 * Unlike the function-object overload this returns nothing. Instead it
 * records fp -> id in m_func_ptr_to_id, so the callback can later be
 * invoked by pointer. The key is fp cast to `void (*)()`. Registering the
 * same fp again overwrites the stored id with the newer one.
 */
304 template <
class... Args>
void add(
void (*fp)(Args...)) {
305 m_callbacks.emplace_back(detail::make_model(fp));
306 const int id = m_callback_map.add(m_callbacks.back().get());
307 m_func_ptr_to_id[
reinterpret_cast<void (*)()
>(fp)] =
id;
/**
 * Add a free function to the process-wide static_callbacks() list, keyed by
 * fp cast to `void (*)()`.
 *
 * No id is assigned here. The MpiCallbacks constructor walks
 * static_callbacks() and assigns an id to each entry. Functions added after
 * an instance was constructed are presumably not visible to it — confirm.
 */
318 template <
class... Args>
static void add_static(
void (*fp)(Args...)) {
319 static_callbacks().emplace_back(
reinterpret_cast<void (*)()
>(fp),
320 detail::make_model(fp));
/**
 * Remove the callback with the given id.
 *
 * Destroys the owned model whose pointer equals m_callback_map[id] (for
 * static callbacks nothing in m_callbacks matches), then releases the id
 * in m_callback_map.
 * NOTE(review): m_func_ptr_to_id is not updated here. A function pointer
 * registered via add(fp) keeps a stale id, so a later call(fp, ...) would
 * look up a removed id — confirm whether that is intended.
 */
332 void remove(
int id) {
333 std::erase_if(m_callbacks, [ptr = m_callback_map[
id]](
auto const &e) {
334 return e.get() == ptr;
336 m_callback_map.remove(
id);
/**
 * Invoke the callback with the given id on the other ranks.
 *
 * Must be called on rank 0; on any other rank it throws std::logic_error.
 * The arguments are packed into a packed_oarchive and broadcast from
 * rank 0. NOTE(review): the lines that write the id and the argument tuple
 * into `oa` are not part of this listing — presumably both are serialized;
 * confirm.
 *
 * @param id    callback id as returned by add()
 * @param args  arguments forwarded into the archive
 * @throws std::logic_error if not called on rank 0
 */
351 template <
class... Args>
void call(
int id, Args &&...args)
const {
352 if (m_comm.rank() != 0) {
353 throw std::logic_error(
"Callbacks can only be invoked on rank 0.");
// NOTE(review): this assert's message refers to m_func_ptr_to_id, but it
// also fires when this overload is called directly with an unknown id.
356 assert(m_callback_map.find(
id) != m_callback_map.end() &&
357 "m_callback_map and m_func_ptr_to_id disagree");
360 boost::mpi::packed_oarchive oa(m_comm);
365 std::forward_as_tuple(std::forward<Args>(args)...));
367 boost::mpi::broadcast(m_comm, oa, 0);
/**
 * Invoke a callback that was registered by function pointer.
 *
 * The id is looked up in m_func_ptr_to_id with .at(), so an unregistered
 * fp throws std::out_of_range. The call is then delegated to call(id, ...).
 * This overload only takes part in overload resolution when fp(args...) is
 * well-formed and returns void.
 */
381 template <
class... Args,
class... ArgRef>
382 auto call(
void (*fp)(Args...), ArgRef &&...args) const ->
384 std::enable_if_t<std::is_void_v<decltype(fp(args...))>> {
385 const int id = m_func_ptr_to_id.at(
reinterpret_cast<void (*)()
>(fp));
387 call(
id, std::forward<ArgRef>(args)...);
/**
 * Call a callback on all nodes, identified by function pointer.
 * Like call(fp, ...), this only participates in overload resolution when
 * fp(args...) is well-formed and returns void.
 * NOTE(review): the body of this function is not part of this listing.
 */
400 template <
class... Args,
class... ArgRef>
401 auto call_all(
void (*fp)(Args...), ArgRef &&...args) const ->
403 std::enable_if_t<std::is_void_v<decltype(fp(args...))>> {
420 boost::mpi::packed_iarchive ia(m_comm);
421 boost::mpi::broadcast(m_comm, ia, 0);
426 if (request == LOOP_ABORT) {
430 m_callback_map[request]->operator()(m_comm, ia);
444 boost::mpi::communicator
const &
comm()
const {
return m_comm; }
454 static constexpr int LOOP_ABORT = 0;
459 boost::mpi::communicator m_comm;
464 std::shared_ptr<boost::mpi::environment> m_mpi_env;
469 std::vector<std::unique_ptr<detail::callback_concept_t>> m_callbacks;
480 std::unordered_map<void (*)(),
int> m_func_ptr_to_id;
483template <
class... Args>
510#define REGISTER_CALLBACK(cb) \
511 namespace Communication { \
512 static ::Communication::RegisterCallback register_##cb(&(cb)); \
Keep an enumerated list of T objects, managed by the class.
RAII handle for a callback.
auto operator()(ArgRef &&...args) const -> std::enable_if_t< std::is_void_v< decltype(std::declval< void(*)(Args...)>()(std::forward< ArgRef >(args)...))> >
Call the callback managed by this handle.
CallbackHandle(CallbackHandle &&rhs) noexcept=default
CallbackHandle(CallbackHandle const &)=delete
CallbackHandle & operator=(CallbackHandle &&rhs) noexcept=default
CallbackHandle(std::shared_ptr< MpiCallbacks > cb, F &&f)
CallbackHandle & operator=(CallbackHandle const &)=delete
The interface of the MPI callback mechanism.
auto call_all(void(*fp)(Args...), ArgRef &&...args) const -> std::enable_if_t< std::is_void_v< decltype(fp(args...))> >
Call a callback on all nodes.
auto call(void(*fp)(Args...), ArgRef &&...args) const -> std::enable_if_t< std::is_void_v< decltype(fp(args...))> >
Call a callback on worker nodes.
MpiCallbacks(boost::mpi::communicator comm, std::shared_ptr< boost::mpi::environment > mpi_env)
void add(void(*fp)(Args...))
Add a new callback.
std::shared_ptr< boost::mpi::environment > share_mpi_env() const
boost::mpi::communicator const & comm() const
The boost mpi communicator used by this instance.
void abort_loop()
Abort the MPI loop.
static void add_static(void(*fp)(Args...))
Add a new callback.
MpiCallbacks & operator=(MpiCallbacks const &)=delete
MpiCallbacks(MpiCallbacks const &)=delete
void loop() const
Start the MPI loop.
Helper class to add callbacks before main.
RegisterCallback()=delete
RegisterCallback(void(*cb)(Args...))
Container for objects that are identified by a numeric id.
void for_each(F &&f, Tuple &t)
static SteepestDescentParameters params
Currently active steepest descent instance.
Algorithms for tuple-like inhomogeneous containers.