39#include <boost/mpi/collectives/broadcast.hpp>
40#include <boost/mpi/communicator.hpp>
41#include <boost/mpi/environment.hpp>
42#include <boost/mpi/packed_iarchive.hpp>
/**
 * @brief Type trait: true when @p T may be used as an MPI callback
 *        parameter type.
 *
 * Evaluates to false for pointer types and for lvalue references to
 * non-const types; value types, const lvalue references and rvalue
 * references are accepted by this expression. invoke() deserializes the
 * arguments into a local tuple, so writes through such an "output"
 * parameter would not be seen by the caller.
 */
62using is_allowed_argument =
63 std::integral_constant<bool,
64 not(std::is_pointer_v<T> ||
65 (!std::is_const_v<std::remove_reference_t<T>> &&
66 std::is_lvalue_reference_v<T>))>;
/**
 * @brief Deserialize callback arguments from an MPI archive and call @p f.
 *
 * @tparam F    Callable type.
 * @tparam Args Parameter types of the callback, in order. Each one must
 *              satisfy is_allowed_argument; this is enforced by the
 *              static_assert below.
 * @param f     Callable to invoke (taken by value).
 * @param ia    Archive the arguments are read from, one value per entry
 *              of @p Args, in declaration order.
 * @return      Whatever @p f returns.
 */
template <
class F,
class... Args>
80auto invoke(F f, boost::mpi::packed_iarchive &ia) {
81 static_assert(std::conjunction_v<is_allowed_argument<Args>...>,
82 "Pointers and non-const references are not allowed as "
83 "arguments for callbacks.");
// Local receive buffer. const and references are stripped from each type
// so the values can be deserialized into it.
87 std::tuple<std::remove_const_t<std::remove_reference_t<Args>>...>
params;
// Read every tuple element from the archive. The comma fold is evaluated
// left to right, matching the order the arguments were written.
88 std::apply([&ia](
auto &&...e) { ((ia >> e), ...); },
params);
// Hand the buffered values over as const lvalues, so f can only receive
// them by value or by const reference.
94 return std::apply(f, std::as_const(
params));
/**
 * @brief Type-erased interface of a registered callback.
 *
 * Implementations receive the communicator and the archive that holds the
 * serialized call arguments. The destructor is virtual because models are
 * owned and destroyed through std::unique_ptr<callback_concept_t>.
 */
104struct callback_concept_t {
// Execute the callback, reading its arguments from the archive.
110 virtual void operator()(boost::mpi::communicator
const &,
111 boost::mpi::packed_iarchive &)
const = 0;
112 virtual ~callback_concept_t() =
default;
/**
 * @brief Callback model wrapping a callable @p F whose parameter types
 *        are @p Args.
 *
 * Instances are neither copyable nor movable. The call operator ignores
 * the communicator, deserializes the arguments via detail::invoke and
 * discards the callable's return value.
 */
template <
class F,
class... Args>
122struct callback_void_t final :
public callback_concept_t {
125 callback_void_t(callback_void_t
const &) =
delete;
126 callback_void_t(callback_void_t &&) =
delete;
// Perfect-forward the callable into the stored member m_f.
// NOTE(review): the declaration of m_f lies outside this excerpt.
128 template <
class FRef>
129 explicit callback_void_t(FRef &&f) : m_f(
std::forward<FRef>(f)) {}
// Deserialize Args from the archive and call the stored callable.
130 void operator()(boost::mpi::communicator
const &,
131 boost::mpi::packed_iarchive &ia)
const override {
132 detail::invoke<F, Args...>(m_f, ia);
/**
 * @brief Compile-time record of a callable's class type, return type and
 *        parameter types (the latter packed into a std::tuple).
 *
 * Value-initialized instances are passed as tags for template argument
 * deduction (see functor_types_impl and make_model_impl).
 */
template <
class F,
class R,
class... Args>
struct FunctorTypes {
137 using functor_type = F;
138 using return_type = R;
139 using argument_types = std::tuple<Args...>;
/**
 * @brief Deduce FunctorTypes from a pointer to a const-qualified member
 *        function.
 *
 * Used below inside decltype on `&F::operator()`. Only a single,
 * non-template, const-qualified call operator can be matched by this
 * signature. Overloaded or generic (template) call operators, and
 * non-const ones, cannot be deduced.
 */
template <
class C,
class R,
class... Args>
143auto functor_types_impl(R (C::*)(Args...) const) {
144 return FunctorTypes<C, R, Args...>{};
// Tail of the alias referenced as detail::functor_types<F>: the
// FunctorTypes deduced from F's call operator (reference stripped from F).
// NOTE(review): the opening lines of this alias are not in this excerpt.
149 decltype(functor_types_impl(&std::remove_reference_t<F>::operator()));
/**
 * @brief Create a heap-allocated callback_void_t model for callable @p c.
 *
 * The unnamed FunctorTypes argument only drives deduction of the class
 * type C and the parameter types Args. The deduced return type R is not
 * part of the resulting model type.
 */
template <
class CRef,
class C,
class R,
class... Args>
152auto make_model_impl(CRef &&c, FunctorTypes<C, R, Args...>) {
153 return std::make_unique<callback_void_t<C, Args...>>(std::forward<CRef>(c));
/**
 * @brief Create a callback model for a callable object (e.g. a lambda).
 *
 * The parameter types are obtained by inspecting the object's call
 * operator.
 */
template <
typename F>
auto make_model(F &&f) {
163 return make_model_impl(std::forward<F>(f), functor_types<F>{});
/**
 * @brief Create a callback model for a plain function pointer.
 *
 * The parameter types are taken directly from the pointer type, and the
 * pointer itself is stored as the callable.
 */
template <
class... Args>
auto make_model(
void (*f_ptr)(Args...)) {
170 return std::make_unique<callback_void_t<void (*)(Args...), Args...>>(f_ptr);
192 template <
typename F>
193 requires(std::is_same_v<typename detail::functor_types<F>::argument_types,
194 std::tuple<Args...>>)
196 : m_id(cb->add(std::forward<F>(f))), m_cb(std::move(cb)) {}
205 std::shared_ptr<MpiCallbacks> m_cb;
215 template <
class... ArgRef>
219 requires(std::is_void_v<
decltype(std::declval<void (*)(Args...)>()(
220 std::forward<ArgRef>(args)...))>)
223 m_cb->call(m_id, std::forward<ArgRef>(args)...);
231 int id()
const {
return m_id; }
239 static auto &static_callbacks() {
241 std::pair<void (*)(), std::unique_ptr<detail::callback_concept_t>>>
249 std::shared_ptr<boost::mpi::environment> mpi_env)
250 : m_comm(
std::move(
comm)), m_mpi_env(
std::move(mpi_env)) {
252 m_callback_map.add(
nullptr);
254 for (
auto &[fp, handle] : static_callbacks()) {
255 m_func_ptr_to_id[fp] = m_callback_map.add(handle.get());
261 if (m_comm.rank() == 0) {
284 template <
typename F>
auto add(F &&f) {
285 m_callbacks.emplace_back(detail::make_model(std::forward<F>(f)));
286 return m_callback_map.add(m_callbacks.back().get());
298 template <
class... Args>
void add(
void (*fp)(Args...)) {
299 m_callbacks.emplace_back(detail::make_model(fp));
300 const int id = m_callback_map.add(m_callbacks.back().get());
301 m_func_ptr_to_id[
reinterpret_cast<void (*)()
>(fp)] =
id;
312 template <
class... Args>
static void add_static(
void (*fp)(Args...)) {
313 static_callbacks().emplace_back(
reinterpret_cast<void (*)()
>(fp),
314 detail::make_model(fp));
326 void remove(
int id) {
327 std::erase_if(m_callbacks, [ptr = m_callback_map[
id]](
auto const &e) {
328 return e.get() == ptr;
330 m_callback_map.remove(
id);
345 template <
class... Args>
void call(
int id, Args &&...args)
const {
346 if (m_comm.rank() != 0) {
347 throw std::logic_error(
"Callbacks can only be invoked on rank 0.");
350 assert(m_callback_map.find(
id) != m_callback_map.end() &&
351 "m_callback_map and m_func_ptr_to_id disagree");
354 boost::mpi::packed_oarchive oa(m_comm);
358 std::apply([&oa](
auto &&...e) { ((oa << e), ...); },
359 std::forward_as_tuple(std::forward<Args>(args)...));
361 boost::mpi::broadcast(m_comm, oa, 0);
375 template <
class... Args,
class... ArgRef>
376 auto call(
void (*fp)(Args...), ArgRef &&...args) const
378 requires(
std::is_void_v<decltype(fp(args...))>)
380 const int id = m_func_ptr_to_id.at(
reinterpret_cast<void (*)()
>(fp));
382 call(
id, std::forward<ArgRef>(args)...);
395 template <
class... Args,
class... ArgRef>
396 auto call_all(
void (*fp)(Args...), ArgRef &&...args) const
398 requires(
std::is_void_v<decltype(fp(args...))>)
416 boost::mpi::packed_iarchive ia(m_comm);
417 boost::mpi::broadcast(m_comm, ia, 0);
422 if (request == LOOP_ABORT) {
426 m_callback_map[request]->operator()(m_comm, ia);
440 boost::mpi::communicator
const &
comm()
const {
return m_comm; }
446 static constexpr int LOOP_ABORT = 0;
451 boost::mpi::communicator m_comm;
456 std::shared_ptr<boost::mpi::environment> m_mpi_env;
461 std::vector<std::unique_ptr<detail::callback_concept_t>> m_callbacks;
472 std::unordered_map<void (*)(),
int> m_func_ptr_to_id;
475template <
class... Args>
502#define REGISTER_CALLBACK(cb) \
503 namespace Communication { \
504 static ::Communication::RegisterCallback register_##cb(&(cb)); \
Keep an enumerated list of T objects, managed by the class.
RAII handle for a callback.
CallbackHandle(CallbackHandle &&rhs) noexcept=default
auto operator()(ArgRef &&...args) const
Call the callback managed by this handle.
CallbackHandle(CallbackHandle const &)=delete
CallbackHandle(std::shared_ptr< MpiCallbacks > cb, F &&f)
CallbackHandle & operator=(CallbackHandle &&rhs) noexcept=default
CallbackHandle & operator=(CallbackHandle const &)=delete
The interface of the MPI callback mechanism.
auto call_all(void(*fp)(Args...), ArgRef &&...args) const
Call a callback on all nodes.
MpiCallbacks(boost::mpi::communicator comm, std::shared_ptr< boost::mpi::environment > mpi_env)
void add(void(*fp)(Args...))
Add a new callback.
boost::mpi::communicator const & comm() const
The boost mpi communicator used by this instance.
void abort_loop()
Abort the MPI loop.
static void add_static(void(*fp)(Args...))
Add a new callback.
MpiCallbacks & operator=(MpiCallbacks const &)=delete
auto call(void(*fp)(Args...), ArgRef &&...args) const
Call a callback on worker nodes.
MpiCallbacks(MpiCallbacks const &)=delete
void loop() const
Start the MPI loop.
Helper class to add callbacks before main.
RegisterCallback()=delete
RegisterCallback(void(*cb)(Args...))
Container for objects that are identified by a numeric id.
static SteepestDescentParameters params
Currently active steepest descent instance.