24#if defined(P3M) or defined(DP3M)
34#include <boost/mpi/communicator.hpp>
35#include <boost/mpi/datatype.hpp>
46 T *
const recvbuf,
int rcount,
int source,
47 boost::mpi::communicator
const &comm,
int tag) {
// Ask Boost.MPI for the MPI datatype that corresponds to element type T.
// NOTE(review): sendbuf is dereferenced to form the argument, so callers
// presumably never pass a null send buffer - confirm for empty buffers.
48 auto const type = boost::mpi::get_mpi_datatype<T>(*sendbuf);
// Single combined exchange: scount elements from sendbuf go to rank dest,
// rcount elements from rank source are written to recvbuf. Both halves use
// the same datatype and the same tag, comm is handed to MPI directly, and
// the receive status is discarded (MPI_STATUS_IGNORE). The return code of
// MPI_Sendrecv is not inspected here.
49 MPI_Sendrecv(
reinterpret_cast<void const *
>(sendbuf), scount, type, dest, tag,
50 reinterpret_cast<void *
>(recvbuf), rcount, type, source, tag,
51 comm, MPI_STATUS_IGNORE);
// Accumulate a contiguous 3D input block into a sub-region of a larger 3D
// output grid: every value read from in is added (+=) to the matching
// element of out.
//   start - index triple of the first output element that is touched
//   size  - extents of the input block (loop bounds below)
//   dim   - extents of the output grid, used for the linear index math
// Index 2 is the fastest-running one in both arrays, index 0 the slowest.
// NOTE(review): the line naming this function is not visible in this
// excerpt; the body matches the p3m_add_block contract - confirm.
63template <
typename FloatType>
65 int const start[3],
int const size[3],
// li_in walks the input block strictly linearly; li_out is the linear
// index into the output grid.
70 int li_in = 0, li_out = 0;
// Gaps to skip in the output grid after a finished row / finished slab.
72 int m_out_offset, s_out_offset;
// Linear index of element start in a grid of extents dim.
74 li_out = start[2] + (dim[2] * (start[1] + (dim[1] * start[0])));
// After size[2] consecutive elements, jump to the beginning of the next row.
75 m_out_offset = dim[2] - size[2];
// After size[1] rows, jump over the remaining rows of the current slab.
76 s_out_offset = (dim[2] * (dim[1] - size[1]));
78 for (s = 0; s < size[0]; s++) {
79 for (m = 0; m < size[1]; m++) {
80 for (f = 0; f < size[2]; f++) {
// Element-wise accumulation; both indices advance by one.
81 out[li_out++] += in[li_in++];
83 li_out += m_out_offset;
85 li_out += s_out_offset;
// Set up the send and receive regions for the six directions (two per
// Cartesian axis; index i*2 and i*2+1 belong to axis i). For every direction
// the lower corner (s_ld / r_ld), upper corner (s_ur / r_ur), extents
// (s_dim / r_dim) and element count (s_size / r_size) are stored, and max
// is raised to the largest element count seen.
// NOTE(review): presumably the body of resize(comm, local_mesh); the line
// with the function name and several if/else lines are not visible in this
// excerpt - confirm against the full file.
89template <
typename FloatType>
// Per-axis 0/1 factor that switches the margin terms below on or off.
// NOTE(review): the statement that updates done[] is not visible here.
92 int done[3] = {0, 0, 0};
94 for (
int i = 0; i < 3; i++) {
95 for (
int j = 0; j < 3; j++) {
// Send region of direction i*2: the lower corner is 0, or the lower margin
// of axis j when done[j] is 1.
97 s_ld[i * 2][j] = 0 + done[j] * local_mesh.
margin[j * 2];
// Upper corner candidate: width of the lower margin of axis j.
// NOTE(review): the condition selecting between this value and the
// expression that follows is on lines not visible here.
99 s_ur[i * 2][j] = local_mesh.
margin[j * 2];
// Alternative upper bound: full extent of axis j, reduced by the upper
// margin when done[j] is 1 (the left-hand side is on a missing line).
102 local_mesh.
dim[j] - done[j] * local_mesh.
margin[(j * 2) + 1];
// Send region of direction i*2+1: the lower corner is either in_ur[j] ...
105 s_ld[(i * 2) + 1][j] = local_mesh.
in_ur[j];
// ... or 0 / the lower margin, as for the even direction (branch condition
// not visible).
107 s_ld[(i * 2) + 1][j] = 0 + done[j] * local_mesh.
margin[j * 2];
// Upper corner of direction i*2+1: full extent minus the upper margin when
// done[j] is 1.
108 s_ur[(i * 2) + 1][j] =
109 local_mesh.
dim[j] - done[j] * local_mesh.
margin[(j * 2) + 1];
// Derive extents and element counts of all six send regions.
114 for (
int i = 0; i < 6; i++) {
116 for (
int j = 0; j < 3; j++) {
117 s_dim[i][j] = s_ur[i][j] - s_ld[i][j];
// s_size[i] accumulates the product of the three extents.
// NOTE(review): its initialisation is on a line not visible here.
118 s_size[i] *= s_dim[i][j];
120 max = std::max(max, s_size[i]);
// Ranks of the six neighbours in the 3D Cartesian communicator.
123 auto const node_neighbors = Utils::Mpi::cart_neighbors<3>(comm);
// Obtain the margins to use on the receive side, one per direction.
126 for (
int i = 0; i < 6; i++) {
// j is the direction opposite to i on the same axis (0<->1, 2<->3, 4<->5).
127 auto const j = (i % 2 == 0) ? i + 1 : i - 1;
// A neighbour equal to this rank means there is nobody else to talk to in
// that direction.
129 if (node_neighbors[i] != comm.rank()) {
// Tail of a call that receives one value into r_margin[j] from
// node_neighbors[j] with tag REQ_P3M_INIT.
// NOTE(review): presumably mesh_sendrecv sending local_mesh.margin[i] to
// node_neighbors[i]; the first line of the call is not visible - confirm.
131 &(r_margin[j]), 1, node_neighbors[j], comm, REQ_P3M_INIT);
// Local copy of the margin instead of a message.
// NOTE(review): presumably the else branch of the check above - confirm.
133 r_margin[j] = local_mesh.
margin[i];
// Receive regions are derived from the send regions.
137 for (
int i = 0; i < 3; i++) {
138 for (
int j = 0; j < 3; j++) {
// Variant 1: corners shifted by margins - the even direction moves up by
// the local lower margin / received lower margin, the odd direction moves
// down by the received upper margin / local upper margin.
// NOTE(review): the condition choosing between variant 1 and variant 2 is
// on lines not visible here.
140 r_ld[i * 2][j] = s_ld[i * 2][j] + local_mesh.
margin[2 * j];
141 r_ur[i * 2][j] = s_ur[i * 2][j] + r_margin[2 * j];
142 r_ld[(i * 2) + 1][j] = s_ld[(i * 2) + 1][j] - r_margin[(2 * j) + 1];
143 r_ur[(i * 2) + 1][j] =
144 s_ur[(i * 2) + 1][j] - local_mesh.
margin[(2 * j) + 1];
// Variant 2: receive corners are identical to the send corners.
146 r_ld[i * 2][j] = s_ld[i * 2][j];
147 r_ur[i * 2][j] = s_ur[i * 2][j];
148 r_ld[(i * 2) + 1][j] = s_ld[(i * 2) + 1][j];
149 r_ur[(i * 2) + 1][j] = s_ur[(i * 2) + 1][j];
// Derive extents and element counts of all six receive regions; max ends up
// covering both send and receive sizes.
153 for (
int i = 0; i < 6; i++) {
155 for (
int j = 0; j < 3; j++) {
156 r_dim[i][j] = r_ur[i][j] - r_ld[i][j];
157 r_size[i] *= r_dim[i][j];
159 max = std::max(max, r_size[i]);
// For each of the six directions: fill the send buffer from the send region
// of every mesh, exchange buffers with the neighbours, and add the received
// blocks onto the receive region of every mesh.
// NOTE(review): presumably gather_grid(comm, meshes, dim); the line with
// the function name is not visible in this excerpt - confirm.
163template <
typename FloatType>
165 std::span<FloatType *> meshes,
167 auto const node_neighbors = Utils::Mpi::cart_neighbors<3>(comm);
// Both buffers hold the largest region (max) once per mesh.
168 send_grid.resize(max * meshes.size());
169 recv_grid.resize(max * meshes.size());
// Directions are processed in ascending order 0..5.
172 for (
int s_dir = 0; s_dir < 6; s_dir++) {
// Data is received for the direction opposite to the one being sent.
173 auto const r_dir = (s_dir % 2 == 0) ? s_dir + 1 : s_dir - 1;
// Nothing to prepare for an empty send region.
176 if (s_size[s_dir] > 0) {
177 for (std::size_t i = 0; i < meshes.size(); i++) {
// Tail of a per-mesh call that is given the send region's lower corner,
// its extents, the grid dimension and the value 1.
// NOTE(review): presumably fft_pack_block copying mesh i into its slice of
// send_grid; the first line of the call is not visible - confirm.
179 s_ld[s_dir], s_dim[s_dir], dim.
data(), 1);
// A neighbour with a different rank requires a real exchange.
184 if (node_neighbors[s_dir] != comm.rank()) {
// Element counts cover all meshes at once; the counts are narrowed to int.
185 auto const send_size =
static_cast<int>(meshes.size()) * s_size[s_dir];
186 auto const recv_size =
static_cast<int>(meshes.size()) * r_size[r_dir];
// Send towards s_dir while receiving from r_dir (the tag argument is on a
// line not visible here).
187 mesh_sendrecv(send_grid.data(), send_size, node_neighbors[s_dir],
188 recv_grid.data(), recv_size, node_neighbors[r_dir], comm,
// Exchange the two buffers locally so the data just prepared is what gets
// read from recv_grid below.
// NOTE(review): presumably the else branch (neighbour is this rank itself);
// the else line is not visible - confirm.
191 std::swap(send_grid, recv_grid);
// Accumulate the received block of every mesh; block i starts at offset
// i * r_size[r_dir] in recv_grid.
194 if (r_size[r_dir] > 0) {
195 for (std::size_t i = 0; i < meshes.size(); i++) {
196 p3m_add_block(recv_grid.data() + i * r_size[r_dir], meshes[i],
197 r_ld[r_dir], r_dim[r_dir], dim.
data());
// Counterpart of the exchange above with the roles reversed: directions are
// visited from 5 down to 0, the receive regions (r_ld / r_dim / r_size) are
// the source of the outgoing data and the send regions (s_ld / s_dim /
// s_size) are the destination of the incoming data.
// NOTE(review): presumably spread_grid(comm, meshes, dim); the line with
// the function name is not visible in this excerpt - confirm.
203template <
typename FloatType>
205 std::span<FloatType *> meshes,
207 auto const node_neighbors = Utils::Mpi::cart_neighbors<3>(comm);
// Both buffers hold the largest region (max) once per mesh.
208 send_grid.resize(max * meshes.size());
209 recv_grid.resize(max * meshes.size());
// Directions are processed in descending order 5..0.
212 for (
int s_dir = 5; s_dir >= 0; s_dir--) {
// Opposite direction on the same axis.
213 auto const r_dir = (s_dir % 2 == 0) ? s_dir + 1 : s_dir - 1;
// Nothing to prepare for an empty outgoing region.
216 if (r_size[r_dir] > 0) {
217 for (std::size_t i = 0; i < meshes.size(); i++) {
// Tail of a per-mesh call that is given the region's lower corner, its
// extents, the grid dimension and the value 1.
// NOTE(review): presumably fft_pack_block filling send_grid from mesh i;
// the first line of the call is not visible - confirm.
219 r_ld[r_dir], r_dim[r_dir], dim.
data(), 1);
// A neighbour with a different rank requires a real exchange.
223 if (node_neighbors[r_dir] != comm.rank()) {
// Element counts cover all meshes at once; the counts are narrowed to int.
224 auto const send_size =
static_cast<int>(meshes.size()) * r_size[r_dir];
225 auto const recv_size =
static_cast<int>(meshes.size()) * s_size[s_dir];
// Send towards r_dir while receiving from s_dir (the tag argument is on a
// line not visible here).
226 mesh_sendrecv(send_grid.data(), send_size, node_neighbors[r_dir],
227 recv_grid.data(), recv_size, node_neighbors[s_dir], comm,
// Exchange the two buffers locally.
// NOTE(review): presumably the else branch (neighbour is this rank itself);
// the else line is not visible - confirm.
230 std::swap(send_grid, recv_grid);
// Write the incoming data into the s_dir region of every mesh.
233 if (s_size[s_dir] > 0) {
234 for (std::size_t i = 0; i < meshes.size(); i++) {
// Tail of a per-mesh call that is given the region's lower corner, its
// extents, the grid dimension and the value 1.
// NOTE(review): presumably fft_unpack_block reading from recv_grid; the
// first line of the call is not visible - confirm.
236 s_ld[s_dir], s_dim[s_dir], dim.
data(), 1);
Vector implementation and trait types for boost qvm interoperability.
void gather_grid(boost::mpi::communicator const &comm, std::span< FloatType * > meshes, Utils::Vector3i const &dim)
void spread_grid(boost::mpi::communicator const &comm, std::span< FloatType * > meshes, Utils::Vector3i const &dim)
void resize(boost::mpi::communicator const &comm, P3MLocalMesh const &local_mesh)
This file contains the defaults for ESPResSo.
Routines, row decomposition, data structures and communication for the 3D-FFT.
Common functions for dipolar and charge P3M.
void fft_pack_block(FloatType const *const in, FloatType *const out, int const *start, int const *size, int const *dim, int element)
Pack a 3D-block of size size starting at start of an input 3D-grid in with dimension dim into an output 3D-block out of size size.
void fft_unpack_block(FloatType const *const in, FloatType *const out, int const *start, int const *size, int const *dim, int element)
Unpack a 3D-block in of size size into an output 3D-grid out of size dim starting at position start.
static void p3m_add_block(FloatType const *in, FloatType *out, int const start[3], int const size[3], int const dim[3])
Add values of a 3d-grid input block (size[3]) to values of 3d-grid output array with dimension dim[3] at start position start[3].
static void mesh_sendrecv(T const *const sendbuf, int scount, int dest, T *const recvbuf, int rcount, int source, boost::mpi::communicator const &comm, int tag)
Properties of the local mesh.
Utils::Vector3i dim
dimension (size) of local mesh.
int in_ur[3]
inner up right grid point + (1,1,1)
int margin[6]
number of margin mesh points.
DEVICE_QUALIFIER constexpr pointer data() noexcept