22#ifndef COMMUNICATION_MPI_CALLBACKS
23#define COMMUNICATION_MPI_CALLBACKS
40#include <boost/mpi/collectives/broadcast.hpp>
41#include <boost/mpi/communicator.hpp>
42#include <boost/mpi/environment.hpp>
43#include <boost/mpi/packed_iarchive.hpp>
65using is_allowed_argument =
66 std::integral_constant<bool,
67 not(std::is_pointer_v<T> ||
68 (!std::is_const_v<std::remove_reference_t<T>> &&
69 std::is_lvalue_reference_v<T>))>;
/**
 * @brief Deserialize the arguments of a callback with parameter types
 * Args... from @p ia and call @p f with them.
 *
 * @param f   The callable to invoke.
 * @param ia  Archive holding the serialized arguments, in parameter order.
 * @return Whatever @p f returns.
 */
82template <
class F,
class... Args>
83auto invoke(F f, boost::mpi::packed_iarchive &ia) {
84 static_assert(std::conjunction_v<is_allowed_argument<Args>...>,
85 "Pointers and non-const references are not allowed as "
86 "arguments for callbacks.");
// Receive buffer: one value per parameter, with references and const
// stripped so that the archive can deserialize into it.
90 std::tuple<std::remove_const_t<std::remove_reference_t<Args>>...>
params;
// Fold over the comma operator: the values are read left to right, i.e. in
// parameter order.
91 std::apply([&ia](
auto &&...e) { ((ia >> e), ...); },
params);
// Hand the values over as const, so f can only take them by value or by
// const reference.
97 return std::apply(f, std::as_const(
params));
/**
 * @brief Type-erased interface of a registered callback.
 *
 * Implementations read the callback's arguments from the given archive and
 * execute the callback. Instances are deleted through this base, hence the
 * virtual destructor.
 */
107struct callback_concept_t {
113 virtual void operator()(boost::mpi::communicator
const &,
114 boost::mpi::packed_iarchive &)
const = 0;
115 virtual ~callback_concept_t() =
default;
/**
 * @brief Callback model wrapping a callable of type F whose parameter types
 * are Args...
 *
 * Calling it deserializes the arguments from the archive and invokes the
 * stored callable m_f; any return value of the callable is discarded.
 */
124template <
class F,
class... Args>
125struct callback_void_t final :
public callback_concept_t {
// Neither copyable nor movable; instances are created through
// std::make_unique and handled by pointer.
128 callback_void_t(callback_void_t
const &) =
delete;
129 callback_void_t(callback_void_t &&) =
delete;
// Forward the callable (lvalue or rvalue) into the stored member.
131 template <
class FRef>
132 explicit callback_void_t(FRef &&f) : m_f(std::forward<FRef>(f)) {}
// The communicator parameter is unnamed and not used by this model.
133 void operator()(boost::mpi::communicator
const &,
134 boost::mpi::packed_iarchive &ia)
const override {
135 detail::invoke<F, Args...>(m_f, ia);
/**
 * @brief Compile-time record of a callable's class type F, its return type R
 * and its parameter types Args... (bundled as a std::tuple in
 * argument_types).
 */
139template <
class F,
class R,
class... Args>
struct FunctorTypes {
140 using functor_type = F;
141 using return_type = R;
142 using argument_types = std::tuple<Args...>;
/**
 * @brief Deduce a FunctorTypes tag from a pointer to a const-qualified call
 * operator of class C.
 *
 * Only a single, non-template, const `operator()` matches this signature.
 * Mutable lambdas, generic lambdas and overloaded call operators therefore
 * cannot be deduced this way.
 */
145template <
class C,
class R,
class... Args>
146auto functor_types_impl(R (C::*)(Args...) const) {
147 return FunctorTypes<C, R, Args...>{};
// Type computed by functor_types<F>: the FunctorTypes of F's call operator,
// with references removed from F first.
152 decltype(functor_types_impl(&std::remove_reference_t<F>::operator()));
/**
 * @brief Create a heap-allocated callback model for the callable @p c.
 *
 * The parameter types Args... are taken from the FunctorTypes tag. The
 * return type R is deduced but not used: the model is always a
 * callback_void_t.
 */
154template <
class CRef,
class C,
class R,
class... Args>
155auto make_model_impl(CRef &&c, FunctorTypes<C, R, Args...>) {
156 return std::make_unique<callback_void_t<C, Args...>>(std::forward<CRef>(c));
/**
 * @brief Create a callback model from a callable object (e.g. a lambda)
 * whose parameter types are deduced from its const call operator.
 */
165template <
typename F>
auto make_model(F &&f) {
166 return make_model_impl(std::forward<F>(f), functor_types<F>{});
/**
 * @brief Overload for plain function pointers returning void; the parameter
 * types are taken directly from the pointer type.
 */
172template <
class... Args>
auto make_model(
void (*f_ptr)(Args...)) {
173 return std::make_unique<callback_void_t<void (*)(Args...), Args...>>(f_ptr);
/**
 * @brief Constructor template: registers @p f with @p cb and keeps @p cb
 * alive for the lifetime of the handle.
 *
 * Enabled only if the parameter types of F's call operator are exactly the
 * handle's Args...
 */
195 template <
typename F,
class = std::enable_if_t<std::is_same_v<
196 typename detail::functor_types<F>::argument_types,
197 std::tuple<Args...>>>>
199 : m_id(cb->add(std::forward<F>(f))), m_cb(std::move(cb)) {}
// NOTE(review): members are initialized in declaration order, and m_cb's
// initializer moves from cb, which m_id's initializer still dereferences.
// This is only correct if m_id is declared before m_cb; that declaration is
// not visible here — confirm.
208 std::shared_ptr<MpiCallbacks> m_cb;
/**
 * @brief Call the callback managed by this handle with @p args, through
 * MpiCallbacks::call.
 *
 * Takes part in overload resolution only if @p args can be passed to a
 * `void(Args...)` function.
 */
218 template <
class... ArgRef>
223 std::is_void_v<
decltype(std::declval<void (*)(Args...)>()(
224 std::forward<ArgRef>(args)...))>> {
226 m_cb->call(m_id, std::forward<ArgRef>(args)...);
234 int id()
const {
return m_id; }
/**
 * @brief Process-wide list of callbacks added through add_static(), stored
 * as (type-erased function pointer, owned model) pairs. Every MpiCallbacks
 * instance registers these entries in its constructor.
 */
242 static auto &static_callbacks() {
244 std::pair<void (*)(), std::unique_ptr<detail::callback_concept_t>>>
252 std::shared_ptr<boost::mpi::environment> mpi_env)
253 : m_comm(std::move(
comm)), m_mpi_env(std::move(mpi_env)) {
// Occupy the first id with a null entry. Presumably this keeps real
// callbacks from getting id LOOP_ABORT (0) — assumes m_callback_map hands
// out ids starting at 0; confirm.
255 m_callback_map.add(
nullptr);
// Register every callback added through add_static() and remember its id
// under its function pointer, so it can later be called by pointer.
257 for (
auto &kv : static_callbacks()) {
258 m_func_ptr_to_id[kv.first] = m_callback_map.add(kv.second.get());
264 if (m_comm.rank() == 0) {
/**
 * @brief Add a callback given as a callable object.
 *
 * The model is owned by m_callbacks; m_callback_map only stores a
 * non-owning pointer to it.
 *
 * @return The id assigned by m_callback_map.
 */
284 template <
typename F>
auto add(F &&f) {
285 m_callbacks.emplace_back(detail::make_model(std::forward<F>(f)));
286 return m_callback_map.add(m_callbacks.back().get());
/**
 * @brief Add a function-pointer callback.
 *
 * Its id is also recorded in m_func_ptr_to_id, keyed by the function
 * pointer, so the callback can later be invoked through the pointer itself.
 */
298 template <
class... Args>
void add(
void (*fp)(Args...)) {
299 m_callbacks.emplace_back(detail::make_model(fp));
300 const int id = m_callback_map.add(m_callbacks.back().get());
// The cast to void (*)() only gives all function pointers a common key
// type; the key is compared, never called.
301 m_func_ptr_to_id[
reinterpret_cast<void (*)()
>(fp)] =
id;
/**
 * @brief Add a function-pointer callback to the process-wide static list.
 *
 * Usable before any MpiCallbacks instance exists. Each instance registers
 * the list's entries when it is constructed.
 */
312 template <
class... Args>
static void add_static(
void (*fp)(Args...)) {
313 static_callbacks().emplace_back(
reinterpret_cast<void (*)()
>(fp),
314 detail::make_model(fp));
/**
 * @brief Remove the callback with id @p id.
 *
 * Destroys the owned model whose pointer is stored under @p id (the pointer
 * is looked up before erasing), then removes the id from m_callback_map.
 *
 * NOTE(review): the lines visible here do not update m_func_ptr_to_id. A
 * removed function-pointer callback would leave a stale id behind — confirm
 * whether that is handled elsewhere.
 */
326 void remove(
int id) {
327 std::erase_if(m_callbacks, [ptr = m_callback_map[
id]](
auto const &e) {
328 return e.get() == ptr;
330 m_callback_map.remove(
id);
/**
 * @brief Invoke the callback with id @p id on the worker ranks.
 *
 * The arguments are serialized into a packed archive, which is broadcast
 * from rank 0.
 *
 * @throws std::logic_error if called on a rank other than 0.
 */
345 template <
class... Args>
void call(
int id, Args &&...args)
const {
346 if (m_comm.rank() != 0) {
347 throw std::logic_error(
"Callbacks can only be invoked on rank 0.");
// Debug-only check that the id is registered.
350 assert(m_callback_map.find(
id) != m_callback_map.end() &&
351 "m_callback_map and m_func_ptr_to_id disagree");
// NOTE(review): the receiving side dispatches on an id taken from the
// broadcast buffer. The statement writing `id` into `oa` before the
// arguments is not visible here — confirm.
354 boost::mpi::packed_oarchive oa(m_comm);
// Serialize the arguments left to right (fold over the comma operator).
358 std::apply([&oa](
auto &&...e) { ((oa << e), ...); },
359 std::forward_as_tuple(std::forward<Args>(args)...));
361 boost::mpi::broadcast(m_comm, oa, 0);
/**
 * @brief Invoke a callback registered by function pointer (via add or
 * add_static) on the worker ranks.
 *
 * Enabled only if @p fp is callable with @p args and returns void.
 *
 * @throws std::out_of_range (from unordered_map::at) if @p fp was never
 * registered.
 */
375 template <
class... Args,
class... ArgRef>
376 auto call(
void (*fp)(Args...), ArgRef &&...args) const ->
378 std::enable_if_t<std::is_void_v<decltype(fp(args...))>> {
379 const int id = m_func_ptr_to_id.at(
reinterpret_cast<void (*)()
>(fp));
381 call(
id, std::forward<ArgRef>(args)...);
/**
 * @brief Call a function-pointer callback on all ranks.
 *
 * Enabled only if @p fp is callable with @p args and returns void.
 */
394 template <
class... Args,
class... ArgRef>
395 auto call_all(
void (*fp)(Args...), ArgRef &&...args) const ->
397 std::enable_if_t<std::is_void_v<decltype(fp(args...))>> {
// Receive one packed request buffer broadcast from rank 0.
414 boost::mpi::packed_iarchive ia(m_comm);
415 boost::mpi::broadcast(m_comm, ia, 0);
// A request equal to LOOP_ABORT ends the loop. Any other request id selects
// the registered callback, which reads its own arguments from the same
// archive.
420 if (request == LOOP_ABORT) {
424 m_callback_map[request]->operator()(m_comm, ia);
438 boost::mpi::communicator
const &
comm()
const {
return m_comm; }
// Request id that makes the worker loop stop.
448 static constexpr int LOOP_ABORT = 0;
// Communicator used for the request broadcasts.
453 boost::mpi::communicator m_comm;
// Shared ownership of the MPI environment: it stays alive at least as long
// as this object.
458 std::shared_ptr<boost::mpi::environment> m_mpi_env;
// Owning storage for models added through add(); m_callback_map holds
// non-owning pointers into it.
463 std::vector<std::unique_ptr<detail::callback_concept_t>> m_callbacks;
// Maps function pointers (type-erased to void (*)()) to their callback ids.
474 std::unordered_map<void (*)(),
int> m_func_ptr_to_id;
// Template head of RegisterCallback, the helper whose static instances are
// created by REGISTER_CALLBACK to register a function-pointer callback.
477template <
class... Args>
// REGISTER_CALLBACK(cb): register the function cb as an MPI callback by
// defining a static RegisterCallback object named register_<cb> inside
// namespace Communication.
504#define REGISTER_CALLBACK(cb) \
505 namespace Communication { \
506 static ::Communication::RegisterCallback register_##cb(&(cb)); \
Keep an enumerated list of T objects, managed by the class.
RAII handle for a callback.
auto operator()(ArgRef &&...args) const -> std::enable_if_t< std::is_void_v< decltype(std::declval< void(*)(Args...)>()(std::forward< ArgRef >(args)...))> >
Call the callback managed by this handle.
CallbackHandle(CallbackHandle &&rhs) noexcept=default
CallbackHandle(CallbackHandle const &)=delete
CallbackHandle & operator=(CallbackHandle &&rhs) noexcept=default
CallbackHandle(std::shared_ptr< MpiCallbacks > cb, F &&f)
CallbackHandle & operator=(CallbackHandle const &)=delete
The interface of the MPI callback mechanism.
auto call_all(void(*fp)(Args...), ArgRef &&...args) const -> std::enable_if_t< std::is_void_v< decltype(fp(args...))> >
Call a callback on all nodes.
auto call(void(*fp)(Args...), ArgRef &&...args) const -> std::enable_if_t< std::is_void_v< decltype(fp(args...))> >
Call a callback on worker nodes.
MpiCallbacks(boost::mpi::communicator comm, std::shared_ptr< boost::mpi::environment > mpi_env)
void add(void(*fp)(Args...))
Add a new callback.
std::shared_ptr< boost::mpi::environment > share_mpi_env() const
boost::mpi::communicator const & comm() const
The boost mpi communicator used by this instance.
void abort_loop()
Abort the MPI loop.
static void add_static(void(*fp)(Args...))
Add a new callback.
MpiCallbacks & operator=(MpiCallbacks const &)=delete
MpiCallbacks(MpiCallbacks const &)=delete
void loop() const
Start the MPI loop.
Helper class to add callbacks before main.
RegisterCallback()=delete
RegisterCallback(void(*cb)(Args...))
Container for objects that are identified by a numeric id.
static SteepestDescentParameters params
Currently active steepest descent instance.