24#if defined(P3M) or defined(DP3M)
34#include <boost/mpi/communicator.hpp>
35#include <boost/mpi/datatype.hpp>
47 T *
const recvbuf,
int rcount,
int source,
48 boost::mpi::communicator
const &comm,
int tag) {
49 auto const type = boost::mpi::get_mpi_datatype<T>(*sendbuf);
50 MPI_Sendrecv(
reinterpret_cast<void const *
>(sendbuf), scount, type, dest, tag,
51 reinterpret_cast<void *
>(recvbuf), rcount, type, source, tag,
52 comm, MPI_STATUS_IGNORE);
64template <
typename FloatType>
66 int const start[3],
int const size[3],
71 int li_in = 0, li_out = 0;
73 int m_out_offset, s_out_offset;
75 li_out = start[2] + (dim[2] * (start[1] + (dim[1] * start[0])));
76 m_out_offset = dim[2] - size[2];
77 s_out_offset = (dim[2] * (dim[1] - size[1]));
79 for (s = 0; s < size[0]; s++) {
80 for (m = 0; m < size[1]; m++) {
81 for (f = 0; f < size[2]; f++) {
82 out[li_out++] += in[li_in++];
84 li_out += m_out_offset;
86 li_out += s_out_offset;
90template <
typename FloatType>
93 int done[3] = {0, 0, 0};
95 for (
int i = 0; i < 3; i++) {
96 for (
int j = 0; j < 3; j++) {
98 s_ld[i * 2][j] = 0 + done[j] * local_mesh.
margin[j * 2];
100 s_ur[i * 2][j] = local_mesh.
margin[j * 2];
103 local_mesh.
dim[j] - done[j] * local_mesh.
margin[(j * 2) + 1];
106 s_ld[(i * 2) + 1][j] = local_mesh.
in_ur[j];
108 s_ld[(i * 2) + 1][j] = 0 + done[j] * local_mesh.
margin[j * 2];
109 s_ur[(i * 2) + 1][j] =
110 local_mesh.
dim[j] - done[j] * local_mesh.
margin[(j * 2) + 1];
115 for (
int i = 0; i < 6; i++) {
117 for (
int j = 0; j < 3; j++) {
118 s_dim[i][j] = s_ur[i][j] - s_ld[i][j];
119 s_size[i] *= s_dim[i][j];
121 max = std::max(max, s_size[i]);
124 auto const node_neighbors = Utils::Mpi::cart_neighbors<3>(comm);
127 for (
int i = 0; i < 6; i++) {
128 auto const j = (i % 2 == 0) ? i + 1 : i - 1;
130 if (node_neighbors[i] != comm.rank()) {
132 &(r_margin[j]), 1, node_neighbors[j], comm, REQ_P3M_INIT);
134 r_margin[j] = local_mesh.
margin[i];
138 for (
int i = 0; i < 3; i++) {
139 for (
int j = 0; j < 3; j++) {
141 r_ld[i * 2][j] = s_ld[i * 2][j] + local_mesh.
margin[2 * j];
142 r_ur[i * 2][j] = s_ur[i * 2][j] + r_margin[2 * j];
143 r_ld[(i * 2) + 1][j] = s_ld[(i * 2) + 1][j] - r_margin[(2 * j) + 1];
144 r_ur[(i * 2) + 1][j] =
145 s_ur[(i * 2) + 1][j] - local_mesh.
margin[(2 * j) + 1];
147 r_ld[i * 2][j] = s_ld[i * 2][j];
148 r_ur[i * 2][j] = s_ur[i * 2][j];
149 r_ld[(i * 2) + 1][j] = s_ld[(i * 2) + 1][j];
150 r_ur[(i * 2) + 1][j] = s_ur[(i * 2) + 1][j];
154 for (
int i = 0; i < 6; i++) {
156 for (
int j = 0; j < 3; j++) {
157 r_dim[i][j] = r_ur[i][j] - r_ld[i][j];
158 r_size[i] *= r_dim[i][j];
160 max = std::max(max, r_size[i]);
164template <
typename FloatType>
166 std::span<FloatType *> meshes,
168 auto const node_neighbors = Utils::Mpi::cart_neighbors<3>(comm);
169 send_grid.resize(max * meshes.size());
170 recv_grid.resize(max * meshes.size());
173 for (
int s_dir = 0; s_dir < 6; s_dir++) {
174 auto const r_dir = (s_dir % 2 == 0) ? s_dir + 1 : s_dir - 1;
177 if (s_size[s_dir] > 0) {
178 for (std::size_t i = 0; i < meshes.size(); i++) {
180 s_ld[s_dir], s_dim[s_dir], dim.
data(), 1);
185 if (node_neighbors[s_dir] != comm.rank()) {
186 auto const send_size =
static_cast<int>(meshes.size()) * s_size[s_dir];
187 auto const recv_size =
static_cast<int>(meshes.size()) * r_size[r_dir];
188 mesh_sendrecv(send_grid.data(), send_size, node_neighbors[s_dir],
189 recv_grid.data(), recv_size, node_neighbors[r_dir], comm,
192 std::swap(send_grid, recv_grid);
195 if (r_size[r_dir] > 0) {
196 for (std::size_t i = 0; i < meshes.size(); i++) {
197 p3m_add_block(recv_grid.data() + i * r_size[r_dir], meshes[i],
198 r_ld[r_dir], r_dim[r_dir], dim.
data());
204template <
typename FloatType>
206 std::span<FloatType *> meshes,
208 auto const node_neighbors = Utils::Mpi::cart_neighbors<3>(comm);
209 send_grid.resize(max * meshes.size());
210 recv_grid.resize(max * meshes.size());
213 for (
int s_dir = 5; s_dir >= 0; s_dir--) {
214 auto const r_dir = (s_dir % 2 == 0) ? s_dir + 1 : s_dir - 1;
217 if (r_size[r_dir] > 0) {
218 for (std::size_t i = 0; i < meshes.size(); i++) {
220 r_ld[r_dir], r_dim[r_dir], dim.
data(), 1);
224 if (node_neighbors[r_dir] != comm.rank()) {
225 auto const send_size =
static_cast<int>(meshes.size()) * r_size[r_dir];
226 auto const recv_size =
static_cast<int>(meshes.size()) * s_size[s_dir];
227 mesh_sendrecv(send_grid.data(), send_size, node_neighbors[r_dir],
228 recv_grid.data(), recv_size, node_neighbors[s_dir], comm,
231 std::swap(send_grid, recv_grid);
234 if (s_size[s_dir] > 0) {
235 for (std::size_t i = 0; i < meshes.size(); i++) {
237 s_ld[s_dir], s_dim[s_dir], dim.
data(), 1);
Vector implementation and trait types for Boost.QVM interoperability.
void gather_grid(boost::mpi::communicator const &comm, std::span< FloatType * > meshes, Utils::Vector3i const &dim)
void spread_grid(boost::mpi::communicator const &comm, std::span< FloatType * > meshes, Utils::Vector3i const &dim)
void resize(boost::mpi::communicator const &comm, P3MLocalMesh const &local_mesh)
This file contains the defaults for ESPResSo.
Routines, row decomposition, data structures and communication for the 3D-FFT.
Common functions for dipolar and charge P3M.
void fft_pack_block(FloatType const *const in, FloatType *const out, int const *start, int const *size, int const *dim, int element)
Pack a 3D block of size `size`, starting at `start`, from an input 3D grid `in` with dimension `dim` into an output ...
void fft_unpack_block(FloatType const *const in, FloatType *const out, int const *start, int const *size, int const *dim, int element)
Unpack a 3D block `in` of size `size` into an output 3D grid `out` of size `dim`, starting at position `start`.
static void p3m_add_block(FloatType const *in, FloatType *out, int const start[3], int const size[3], int const dim[3])
Add the values of a 3D-grid input block (of size `size[3]`) to the values of a 3D-grid output array with dimension `dim[3]`...
static void mesh_sendrecv(T const *const sendbuf, int scount, int dest, T *const recvbuf, int rcount, int source, boost::mpi::communicator const &comm, int tag)
Properties of the local mesh.
Utils::Vector3i dim
Dimension (size) of the local mesh, including halo layers.
int margin[6]
number of margin mesh points.
Utils::Vector3i in_ur
Inner upper-right grid point + (1,1,1).
DEVICE_QUALIFIER constexpr pointer data() noexcept