ESPResSo
Extensible Simulation Package for Research on Soft Matter Systems
Loading...
Searching...
No Matches
fft.cpp
Go to the documentation of this file.
1/*
2 * Copyright (C) 2010-2026 The ESPResSo project
3 * Copyright (C) 2002,2003,2004,2005,2006,2007,2008,2009,2010
4 * Max-Planck-Institute for Polymer Research, Theory Group
5 *
6 * This file is part of ESPResSo.
7 *
8 * ESPResSo is free software: you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation, either version 3 of the License, or
11 * (at your option) any later version.
12 *
13 * ESPResSo is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with this program. If not, see <http://www.gnu.org/licenses/>.
20 */
21/** \file
22 *
23 * Routines, row decomposition, data structures and communication for the
24 * 3D-FFT.
25 *
26 */
27
28#include "fft.hpp"
29#include "vector.hpp"
30
31#include "p3m/packing.hpp"
32
33#include <utils/Vector.hpp>
34#include <utils/index.hpp>
37
38#include <boost/mpi/communicator.hpp>
39#include <boost/serialization/vector.hpp>
40
41#include <fftw3.h>
42
43#include <omp.h>
44
45#include <algorithm>
46#include <cassert>
47#include <cmath>
48#include <cstddef>
49#include <limits>
50#include <optional>
51#include <span>
52#include <stdexcept>
53#include <utility>
54#include <vector>
55
57
/** @name MPI tags for FFT communication */
/**@{*/
/** Tag for communication in forw_grid_comm() */
static constexpr int REQ_FFT_FORW = 301;
/** Tag for communication in back_grid_comm() */
static constexpr int REQ_FFT_BACK = 302;
/**@}*/
65
66template <typename T>
67static void fft_sendrecv(T const *const sendbuf, int scount, int dest,
68 T *const recvbuf, int rcount, int source,
69 boost::mpi::communicator const &comm, int tag) {
70 auto const type = boost::mpi::get_mpi_datatype<T>(*sendbuf);
71 MPI_Sendrecv(reinterpret_cast<void const *>(sendbuf), scount, type, dest, tag,
72 reinterpret_cast<void *>(recvbuf), rcount, type, source, tag,
73 comm, MPI_STATUS_IGNORE);
74}
75
76namespace fft {
77
78template <typename FloatType = double> struct fftw {
79 using complex = fftw_complex;
80 static auto constexpr plan_many_dft = fftw_plan_many_dft;
81 static auto constexpr destroy_plan = fftw_destroy_plan;
82 static auto constexpr execute_dft = fftw_execute_dft;
83 static auto constexpr malloc = fftw_malloc;
84 static auto constexpr free = fftw_free;
85};
86template <> struct fftw<float> {
87 using complex = fftwf_complex;
88 static auto constexpr plan_many_dft = fftwf_plan_many_dft;
89 static auto constexpr destroy_plan = fftwf_destroy_plan;
90 static auto constexpr execute_dft = fftwf_execute_dft;
91 static auto constexpr malloc = fftwf_malloc;
92 static auto constexpr free = fftwf_free;
93};
94
95/** This ugly function does the bookkeeping: which nodes have to
96 * communicate to each other, when you change the node grid.
97 * Changing the regular decomposition requires communication. This
98 * function finds (hopefully) the best way to do this. As input it
99 * needs the two grids (@p grid1, @p grid2) and a linear list (@p node_list1)
100 * with the node identities for @p grid1. The linear list (@p node_list2)
101 * for the second grid is calculated. For the communication group of
102 * the calling node it calculates a list (@c group) with the node
103 * identities and the positions (@p my_pos, @p pos) of that nodes in @p grid1
104 * and @p grid2. The return value is the size of the communication
105 * group. It gives -1 if the two grids do not fit to each other
106 * (@p grid1 and @p grid2 have to be component-wise multiples of each
107 * other, see e.g. \ref calc_2d_grid for how to do this).
108 *
109 * \param[in] grid1 The node grid you start with.
110 * \param[in] grid2 The node grid you want to have.
111 * \param[in] node_list1 Linear node index list for grid1.
112 * \param[out] node_list2 Linear node index list for grid2.
113 * \param[out] pos Positions of the nodes in grid2
114 * \param[out] my_pos Position of comm.rank() in grid2.
115 * \param[in] rank MPI rank.
116 * \return Size of the communication group.
117 */
118std::optional<std::vector<int>>
120 std::span<int const> node_list1, std::span<int> node_list2,
121 std::span<int> pos, std::span<int> my_pos, int rank) {
122 int i;
123 /* communication group cell size on grid1 and grid2 */
124 int s1[3], s2[3];
125 /* The communication group cells build the same super grid on grid1 and grid2
126 */
127 int ds[3];
128 /* communication group size */
129 int g_size = 1;
130 /* comm. group cell index */
131 int gi[3];
132 /* position of a node in a grid */
133 Utils::Vector3i p1, p2;
134 /* node identity */
135 int n;
136 /* comm.rank() position in the communication group. */
137 int c_pos = -1;
138 /* flag for group identification */
139 int my_group = 0;
140
141 /* calculate dimension of comm. group cells for both grids */
142 if (Utils::product(grid1) != Utils::product(grid2))
143 return std::nullopt; /* unlike number of nodes */
144 for (i = 0; i < 3; i++) {
145 s1[i] = grid1[i] / grid2[i];
146 if (s1[i] == 0)
147 s1[i] = 1;
148 else if (grid1[i] != grid2[i] * s1[i])
149 return std::nullopt; /* grids do not match!!! */
150
151 s2[i] = grid2[i] / grid1[i];
152 if (s2[i] == 0)
153 s2[i] = 1;
154 else if (grid2[i] != grid1[i] * s2[i])
155 return std::nullopt; /* grids do not match!!! */
156
157 ds[i] = grid2[i] / s2[i];
158 g_size *= s2[i];
159 }
160
161 std::vector<int> group(g_size);
162
163 /* calc node_list2 */
164 /* loop through all comm. group cells */
165 for (gi[2] = 0; gi[2] < ds[2]; gi[2]++)
166 for (gi[1] = 0; gi[1] < ds[1]; gi[1]++)
167 for (gi[0] = 0; gi[0] < ds[0]; gi[0]++) {
168 /* loop through all nodes in that comm. group cell */
169 for (i = 0; i < g_size; i++) {
170 p1[0] = (gi[0] * s1[0]) + (i % s1[0]);
171 p1[1] = (gi[1] * s1[1]) + ((i / s1[0]) % s1[1]);
172 p1[2] = (gi[2] * s1[2]) + (i / (s1[0] * s1[1]));
173
174 p2[0] = (gi[0] * s2[0]) + (i % s2[0]);
175 p2[1] = (gi[1] * s2[1]) + ((i / s2[0]) % s2[1]);
176 p2[2] = (gi[2] * s2[2]) + (i / (s2[0] * s2[1]));
177
178 n = node_list1[Utils::get_linear_index(p1, grid1)];
179 node_list2[Utils::get_linear_index(p2, grid2)] = n;
180
181 pos[3 * n + 0] = p2[0];
182 pos[3 * n + 1] = p2[1];
183 pos[3 * n + 2] = p2[2];
184 if (my_group == 1)
185 group[i] = n;
186 if (n == rank && my_group == 0) {
187 my_group = 1;
188 c_pos = i;
189 my_pos[0] = p2[0];
190 my_pos[1] = p2[1];
191 my_pos[2] = p2[2];
192 i = -1; /* restart the loop */
193 }
194 }
195 my_group = 0;
196 }
197
198 /* permute comm. group according to the nodes position in the group */
199 /* This is necessary to have matching node pairs during communication! */
200 while (c_pos > 0) {
201 n = group[g_size - 1];
202 for (i = g_size - 1; i > 0; i--)
203 group[i] = group[i - 1];
204 group[0] = n;
205 c_pos--;
206 }
207 return {group};
208}
209
210namespace {
/** Calculate the share of a global FFT mesh owned by one node.
 *
 * For a node at position @p n_pos in the node grid @p n_grid, compute the
 * first global mesh point (@p start) and the number of mesh points per
 * dimension (@p loc_mesh) of its part of a global mesh of size @p mesh that
 * is shifted by @p mesh_off (in mesh units). A mesh point lying
 * (numerically, within 1e-15) on the upper boundary of a node's domain is
 * not counted for that node; the two 1e-15 comparisons correct round-off
 * errors of the ceil/floor operations.
 *
 * \param[in]  n_pos     Position of the node in @p n_grid.
 * \param[in]  n_grid    Node grid.
 * \param[in]  mesh      Global mesh dimensions.
 * \param[in]  mesh_off  Global mesh offset.
 * \param[out] loc_mesh  Local mesh dimensions.
 * \param[out] start     First point of the local mesh in the global mesh.
 * \return Number of mesh points in the local mesh.
 */
int calc_local_mesh(const int *n_pos, const int *n_grid, const int *mesh,
                    const double *mesh_off, int *loc_mesh, int *start) {
  int n_points = 1;
  for (int dim = 0; dim < 3; ++dim) {
    /* (possibly fractional) number of mesh points per node */
    double const per_node = mesh[dim] / static_cast<double>(n_grid[dim]);
    int first =
        static_cast<int>(std::ceil(per_node * n_pos[dim] - mesh_off[dim]));
    int last = static_cast<int>(
        std::floor(per_node * (n_pos[dim] + 1) - mesh_off[dim]));
    /* correct round off errors */
    if (per_node * (n_pos[dim] + 1) - mesh_off[dim] - last < 1.0e-15) {
      --last;
    }
    if (1.0 + per_node * n_pos[dim] - mesh_off[dim] - first < 1.0e-15) {
      --first;
    }
    start[dim] = first;
    loc_mesh[dim] = last - first + 1;
    n_points *= loc_mesh[dim];
  }
  return n_points;
}
242
243/** Calculate a send (or recv.) block for grid communication during a
244 * decomposition change. Calculate the send block specification
245 * (block = lower left corner and upper right corner) which a node at
246 * position (@p pos1) in the actual node grid (@p grid1) has to send to
247 * another node at position (@p pos2) in the desired node grid (@p grid2).
248 * The global mesh, subject to communication, is specified via its size
249 * (@p mesh) and its mesh offset (@p mesh_off (in mesh units)).
250 *
251 * For the calculation of a receive block you have to change the arguments in
252 * the following way:
253 * - @p pos1: position of receiving node in the desired node grid.
254 * - @p grid1: desired node grid.
255 * - @p pos2: position of the node you intend to receive the data from in the
256 * actual node grid.
257 * - @p grid2: actual node grid.
258 *
259 * \param[in] pos1 Position of send node in @p grid1.
260 * \param[in] grid1 node grid 1.
261 * \param[in] pos2 Position of recv node in @p grid2.
262 * \param[in] grid2 node grid 2.
263 * \param[in] mesh global mesh dimensions.
264 * \param[in] mesh_off global mesh offset.
265 * \param[out] block send block specification.
266 * \return Size of the send block.
267 */
268int calc_send_block(const int *pos1, const int *grid1, const int *pos2,
269 const int *grid2, const int *mesh, const double *mesh_off,
270 int *block) {
271 int size = 1;
272 int mesh1[3], first1[3], last1[3];
273 int mesh2[3], first2[3], last2[3];
274
275 calc_local_mesh(pos1, grid1, mesh, mesh_off, mesh1, first1);
276 calc_local_mesh(pos2, grid2, mesh, mesh_off, mesh2, first2);
277
278 for (int i = 0; i < 3; i++) {
279 last1[i] = first1[i] + mesh1[i] - 1;
280 last2[i] = first2[i] + mesh2[i] - 1;
281 block[i] = std::max(first1[i], first2[i]) - first1[i];
282 block[i + 3] = (std::min(last1[i], last2[i]) - first1[i]) - block[i] + 1;
283 size *= block[i + 3];
284 }
285 return size;
286}
287
/** Pack a block with dimensions <tt>size[0] * size[1] * size[2]</tt> starting
 * at @p start of an input 3D-grid with dimension @p dim into an output
 * 3D-grid with dimensions <tt>size[1] * size[2] * size[0]</tt> (slowest to
 * fastest index) with a simultaneous one-fold permutation of the indices.
 * The permutation is defined as: slow_in -> fast_out, mid_in -> slow_out,
 * fast_in -> mid_out.
 *
 * An element <tt>(i0_in, i1_in, i2_in)</tt> is then
 * <tt>(i0_out = i1_in-start[1], i1_out = i2_in-start[2],
 * i2_out = i0_in-start[0])</tt> and for the linear indices (in units of
 * @p element) we have:
 * - <tt>li_in = i2_in + dim[2] * (i1_in + (dim[1]*i0_in))</tt>
 * - <tt>li_out = i2_out + size[0] * (i1_out + (size[2]*i0_out))</tt>
 *
 * For index definition see \ref fft_pack_block.
 *
 * \param[in]  in       input 3D-grid.
 * \param[out] out      output 3D-grid (block).
 * \param[in]  start    start index of the block in the in-grid.
 * \param[in]  size     size of the block (=dimension of the out-grid).
 * \param[in]  dim      size of the in-grid.
 * \param[in]  element  size of a grid element (e.g. 1 for Real, 2 for
 *                      Complex).
 */
template <typename FloatType>
void pack_block_permute1(FloatType const *const in, FloatType *const out,
                         const int *start, const int *size, const int *dim,
                         int element) {
  for (int s = 0; s < size[0]; s++) {   /* slow in -> fast out */
    for (int m = 0; m < size[1]; m++) { /* mid in -> slow out */
      /* first element of the input row at (start[0]+s, start[1]+m, start[2]) */
      auto const *const row_in =
          in + element * (start[2] + dim[2] * ((start[1] + m) +
                                               dim[1] * (start[0] + s)));
      for (int f = 0; f < size[2]; f++) { /* fast in -> mid out */
        auto *const elem_out =
            out + element * (s + size[0] * (f + size[2] * m));
        std::copy_n(row_in + element * f, element, elem_out);
      }
    }
  }
}
336
/** Pack a block with dimensions <tt>size[0] * size[1] * size[2]</tt> starting
 * at @p start of an input 3D-grid with dimension @p dim into an output
 * 3D-grid with dimensions <tt>size[2] * size[0] * size[1]</tt> (slowest to
 * fastest index) with a simultaneous two-fold permutation of the indices.
 * The permutation is defined as: slow_in -> mid_out, mid_in -> fast_out,
 * fast_in -> slow_out.
 *
 * An element <tt>(i0_in, i1_in, i2_in)</tt> is then
 * <tt>(i0_out = i2_in-start[2], i1_out = i0_in-start[0],
 * i2_out = i1_in-start[1])</tt> and for the linear indices (in units of
 * @p element) we have:
 * - <tt>li_in = i2_in + dim[2] * (i1_in + (dim[1]*i0_in))</tt>
 * - <tt>li_out = i2_out + size[1] * (i1_out + (size[0]*i0_out))</tt>
 *
 * For index definition see \ref fft_pack_block.
 *
 * \param[in]  in       input 3D-grid.
 * \param[out] out      output 3D-grid (block).
 * \param[in]  start    start index of the block in the in-grid.
 * \param[in]  size     size of the block (=dimension of the out-grid).
 * \param[in]  dim      size of the in-grid.
 * \param[in]  element  size of a grid element (e.g. 1 for Real, 2 for
 *                      Complex).
 */
template <typename FloatType>
void pack_block_permute2(FloatType const *const in, FloatType *const out,
                         const int *start, const int *size, const int *dim,
                         int element) {
  for (int s = 0; s < size[0]; s++) {   /* slow in -> mid out */
    for (int m = 0; m < size[1]; m++) { /* mid in -> fast out */
      /* first element of the input row at (start[0]+s, start[1]+m, start[2]) */
      auto const *const row_in =
          in + element * (start[2] + dim[2] * ((start[1] + m) +
                                               dim[1] * (start[0] + s)));
      for (int f = 0; f < size[2]; f++) { /* fast in -> slow out */
        auto *const elem_out =
            out + element * (m + size[1] * (s + size[0] * f));
        std::copy_n(row_in + element * f, element, elem_out);
      }
    }
  }
}
386
387} // namespace
388
389/** Communicate the grid data according to the given forward FFT plan.
390 * \param comm MPI communicator.
391 * \param plan FFT communication plan.
392 * \param in input mesh.
393 * \param out output mesh.
394 */
395template <typename FloatType>
396void fft_data_struct<FloatType>::forw_grid_comm(
397 boost::mpi::communicator const &comm, fft_forw_plan<FloatType> const &plan,
398 FloatType const *in, FloatType *out) {
399 for (std::size_t i = 0ul; i < plan.group.size(); i++) {
400 plan.pack_function(in, send_buf.data(), &(plan.send_block[6ul * i]),
401 &(plan.send_block[6ul * i + 3ul]), plan.old_mesh.data(),
402 plan.element);
403
404 if (plan.group[i] != comm.rank()) {
405 fft_sendrecv(send_buf.data(), plan.send_size[i], plan.group[i],
406 recv_buf.data(), plan.recv_size[i], plan.group[i], comm,
408 } else { /* Self communication... */
409 std::swap(send_buf, recv_buf);
410 }
411 fft_unpack_block(recv_buf.data(), out, &(plan.recv_block[6ul * i]),
412 &(plan.recv_block[6ul * i + 3ul]), plan.new_mesh.data(),
413 plan.element);
414 }
415}
416
417/** Communicate the grid data according to the given backward FFT plan.
418 * \param comm MPI communicator.
419 * \param plan_f Forward FFT plan.
420 * \param plan_b Backward FFT plan.
421 * \param in input mesh.
422 * \param out output mesh.
423 */
424template <typename FloatType>
425void fft_data_struct<FloatType>::back_grid_comm(
426 boost::mpi::communicator const &comm,
427 fft_forw_plan<FloatType> const &plan_f,
428 fft_back_plan<FloatType> const &plan_b, FloatType const *in,
429 FloatType *out) {
430 /* Back means: Use the send/receive stuff from the forward plan but
431 replace the receive blocks by the send blocks and vice
432 versa. Attention then also new_mesh and old_mesh are exchanged */
433
434 for (std::size_t i = 0ul; i < plan_f.group.size(); i++) {
435 plan_b.pack_function(in, send_buf.data(), &(plan_f.recv_block[6ul * i]),
436 &(plan_f.recv_block[6ul * i + 3ul]),
437 plan_f.new_mesh.data(), plan_f.element);
438
439 if (plan_f.group[i] != comm.rank()) { /* send first, receive second */
440 fft_sendrecv(send_buf.data(), plan_f.recv_size[i], plan_f.group[i],
441 recv_buf.data(), plan_f.send_size[i], plan_f.group[i], comm,
443 } else { /* Self communication... */
444 std::swap(send_buf, recv_buf);
445 }
446 fft_unpack_block(recv_buf.data(), out, &(plan_f.send_block[6ul * i]),
447 &(plan_f.send_block[6ul * i + 3ul]),
448 plan_f.old_mesh.data(), plan_f.element);
449 }
450}
451
/** Calculate the 'best' mapping between a 2D and a 3D node grid.
 *
 *  Required for the communication from the 3D regular domain decomposition
 *  to the 2D regular row decomposition. If necessary, the dimensions of
 *  @p g2d are reordered in place so that they are multiples of the
 *  corresponding dimensions of @p g3d; the dimension of extent 1 in the
 *  reordered @p g2d is the row direction.
 *
 * \param[in]     g3d  3D grid.
 * \param[in,out] g2d  2D grid (third entry expected to be 1 on input).
 * \return index of the row direction [0,1,2], or -1 if no suitable mapping
 *         was found (in which case @p g2d is left unchanged).
 */
int map_3don2d_grid(int const g3d[3], int g2d[3]) {
  /* trivial case: the 3D grid is already flat in the last dimension */
  if (g3d[2] == 1)
    return 2;

  /* does dimension i2 of the 2D grid divide evenly by dimension i3 of the
   * 3D grid? */
  auto const fits = [&](int i2, int i3) { return g2d[i2] % g3d[i3] == 0; };

  if (fits(0, 0)) {
    if (fits(1, 1))
      return 2;
    if (fits(1, 2)) {
      g2d[2] = g2d[1];
      g2d[1] = 1;
      return 1;
    }
    return -1;
  }

  if (fits(0, 1)) {
    if (fits(1, 0)) {
      std::swap(g2d[0], g2d[1]);
      return 2;
    }
    if (fits(1, 2)) {
      g2d[2] = g2d[1];
      g2d[1] = g2d[0];
      g2d[0] = 1;
      return 0;
    }
    return -1;
  }

  if (fits(0, 2)) {
    if (fits(1, 0)) {
      g2d[2] = g2d[0];
      g2d[0] = g2d[1];
      g2d[1] = 1;
      return 1;
    }
    if (fits(1, 1)) {
      g2d[2] = g2d[0];
      g2d[0] = 1;
      return 0;
    }
  }
  return -1;
}
501
/** Calculate the most square 2D grid for @p n nodes.
 *
 *  The largest divisor of @p n not exceeding its square root becomes
 *  @c grid[1], the complementary factor @c grid[0]; @c grid[2] is always 1.
 *  For a prime @p n this yields <tt>{n, 1, 1}</tt>.
 *
 * \param[in]  n     Number of nodes.
 * \param[out] grid  Resulting 2D grid.
 */
inline void calc_2d_grid(int n, int grid[3]) {
  int divisor = 1;
  for (auto d = static_cast<int>(std::sqrt(n)); d > 1; --d) {
    if (n % d == 0) {
      divisor = d;
      break;
    }
  }
  grid[0] = n / divisor;
  grid[1] = divisor;
  grid[2] = 1;
}
516
517template <typename FloatType>
519 boost::mpi::communicator const &comm, Utils::Vector3i const &ca_mesh_dim,
520 int const *ca_mesh_margin, Utils::Vector3i const &global_mesh_dim,
521 Utils::Vector3d const &global_mesh_off, int &ks_pnum,
522 Utils::Vector3i const &grid) {
523
524 int n_grid[4][3]; /* The four node grids. */
525 int my_pos[4][3]; /* The position of comm.rank() in the node grids. */
526 std::vector<int> n_id[4]; /* linear node identity lists for the node grids. */
527 std::vector<int> n_pos[4]; /* positions of nodes in the node grids. */
528
529 auto const rank = comm.rank();
530 auto const node_pos = Utils::Mpi::cart_coords<3>(comm, rank);
531
532 max_comm_size = 0;
533 max_mesh_size = 0;
534 for (int i = 0; i < 4; i++) {
535 n_id[i].resize(1 * comm.size());
536 n_pos[i].resize(3 * comm.size());
537 }
538
539 /* === node grids === */
540 /* real space node grid (n_grid[0]) */
541 for (int i = 0; i < 3; i++) {
542 n_grid[0][i] = grid[i];
543 my_pos[0][i] = node_pos[i];
544 }
545 for (int i = 0; i < comm.size(); i++) {
546 auto const n_pos_i = Utils::Mpi::cart_coords<3>(comm, i);
547 for (int j = 0; j < 3; ++j) {
548 n_pos[0][3 * i + j] = n_pos_i[j];
549 }
550 auto const lin_ind = Utils::get_linear_index(
551 n_pos_i, {n_grid[0][0], n_grid[0][1], n_grid[0][2]});
552 n_id[0][lin_ind] = i;
553 }
554
555 /* FFT node grids (n_grid[1 - 3]) */
556 calc_2d_grid(comm.size(), n_grid[1]);
557 /* resort n_grid[1] dimensions if necessary */
558 forw[1].row_dir = map_3don2d_grid(n_grid[0], n_grid[1]);
559 forw[0].n_permute = 0;
560 for (int i = 1; i < 4; i++)
561 forw[i].n_permute = (forw[1].row_dir + i) % 3;
562 for (int i = 0; i < 3; i++) {
563 n_grid[2][i] = n_grid[1][(i + 1) % 3];
564 n_grid[3][i] = n_grid[1][(i + 2) % 3];
565 }
566 forw[2].row_dir = (forw[1].row_dir - 1) % 3;
567 forw[3].row_dir = (forw[1].row_dir - 2) % 3;
568
569 /* === communication groups === */
570 /* copy local mesh off real space charge assignment grid */
571 for (int i = 0; i < 3; i++)
572 forw[0].new_mesh[i] = ca_mesh_dim[i];
573
574 for (int i = 1; i < 4; i++) {
575 auto group = find_comm_groups(
576 {n_grid[i - 1][0], n_grid[i - 1][1], n_grid[i - 1][2]},
577 {n_grid[i][0], n_grid[i][1], n_grid[i][2]}, n_id[i - 1],
578 std::span(n_id[i]), std::span(n_pos[i]), my_pos[i], rank);
579 if (not group) {
580 /* try permutation */
581 std::swap(n_grid[i][(forw[i].row_dir + 1) % 3],
582 n_grid[i][(forw[i].row_dir + 2) % 3]);
583
584 group = find_comm_groups(
585 {n_grid[i - 1][0], n_grid[i - 1][1], n_grid[i - 1][2]},
586 {n_grid[i][0], n_grid[i][1], n_grid[i][2]}, std::span(n_id[i - 1]),
587 std::span(n_id[i]), std::span(n_pos[i]), my_pos[i], rank);
588
589 if (not group) {
590 throw std::runtime_error("INTERNAL ERROR: fft_find_comm_groups error");
591 }
592 }
593
594 forw[i].group = group.value();
595
596 forw[i].send_block.resize(6 * forw[i].group.size());
597 forw[i].send_size.resize(forw[i].group.size());
598 forw[i].recv_block.resize(6 * forw[i].group.size());
599 forw[i].recv_size.resize(forw[i].group.size());
600
601 forw[i].new_size = calc_local_mesh(
602 my_pos[i], n_grid[i], global_mesh_dim.data(), global_mesh_off.data(),
603 forw[i].new_mesh.data(), forw[i].start.data());
604 permute_ifield(forw[i].new_mesh.data(), 3, -(forw[i].n_permute));
605 permute_ifield(forw[i].start.data(), 3, -(forw[i].n_permute));
606 forw[i].n_ffts = forw[i].new_mesh[0] * forw[i].new_mesh[1];
607
608 /* === send/recv block specifications === */
609 for (std::size_t j = 0ul; j < forw[i].group.size(); j++) {
610 /* send block: comm.rank() to comm-group-node i (identity: node) */
611 int node = forw[i].group[j];
612 forw[i].send_size[j] = calc_send_block(
613 my_pos[i - 1], n_grid[i - 1], &(n_pos[i][3 * node]), n_grid[i],
614 global_mesh_dim.data(), global_mesh_off.data(),
615 &(forw[i].send_block[6ul * j]));
616 permute_ifield(&(forw[i].send_block[6ul * j]), 3,
617 -(forw[i - 1].n_permute));
618 permute_ifield(&(forw[i].send_block[6ul * j + 3ul]), 3,
619 -(forw[i - 1].n_permute));
620 if (forw[i].send_size[j] > max_comm_size)
621 max_comm_size = forw[i].send_size[j];
622 /* First plan send blocks have to be adjusted, since the CA grid
623 may have an additional margin outside the actual domain of the
624 node */
625 if (i == 1) {
626 for (std::size_t k = 0ul; k < 3ul; k++)
627 forw[1].send_block[6ul * j + k] += ca_mesh_margin[2ul * k];
628 }
629 /* recv block: comm.rank() from comm-group-node i (identity: node) */
630 forw[i].recv_size[j] = calc_send_block(
631 my_pos[i], n_grid[i], &(n_pos[i - 1][3 * node]), n_grid[i - 1],
632 global_mesh_dim.data(), global_mesh_off.data(),
633 &(forw[i].recv_block[6ul * j]));
634 permute_ifield(&(forw[i].recv_block[6ul * j]), 3, -(forw[i].n_permute));
635 permute_ifield(&(forw[i].recv_block[6ul * j + 3ul]), 3,
636 -(forw[i].n_permute));
637 if (forw[i].recv_size[j] > max_comm_size)
638 max_comm_size = forw[i].recv_size[j];
639 }
640
641 for (std::size_t j = 0ul; j < 3ul; j++)
642 forw[i].old_mesh[j] = forw[i - 1].new_mesh[j];
643 if (i == 1) {
644 forw[i].element = 1;
645 } else {
646 forw[i].element = 2;
647 for (std::size_t j = 0ul; j < forw[i].group.size(); j++) {
648 forw[i].send_size[j] *= 2;
649 forw[i].recv_size[j] *= 2;
650 }
651 }
652 }
653
654 /* Factor 2 for complex fields */
655 max_comm_size *= 2;
656 max_mesh_size = Utils::product(ca_mesh_dim);
657 for (int i = 1; i < 4; i++)
658 if (2 * forw[i].new_size > max_mesh_size)
659 max_mesh_size = 2 * forw[i].new_size;
660
661 /* === pack function === */
662 for (int i = 1; i < 4; i++) {
663 forw[i].pack_function = pack_block_permute2;
664 }
665 ks_pnum = 6;
666 if (forw[1].row_dir == 2) {
667 forw[1].pack_function = fft_pack_block;
668 ks_pnum = 4;
669 } else if (forw[1].row_dir == 1) {
670 forw[1].pack_function = pack_block_permute1;
671 ks_pnum = 5;
672 }
673
674 send_buf.resize(max_comm_size);
675 recv_buf.resize(max_comm_size);
676 data_buf.resize(max_mesh_size);
677 auto *c_data = (typename fftw<FloatType>::complex *)(data_buf.data());
678
679 /* === FFT Routines (Using FFTW / RFFTW package)=== */
680 for (int i = 1; i < 4; i++) {
681 if (init_tag) {
682#pragma omp critical(fftw_destroy_plan_forward)
683 forw[i].destroy_plan();
684 }
685 forw[i].dir = FFTW_FORWARD;
686#pragma omp critical(fftw_create_plan_forward)
687 forw[i].plan_handle = fftw<FloatType>::plan_many_dft(
688 1, &forw[i].new_mesh[2], forw[i].n_ffts, c_data, nullptr, 1,
689 forw[i].new_mesh[2], c_data, nullptr, 1, forw[i].new_mesh[2],
690 forw[i].dir, FFTW_PATIENT);
691 assert(forw[i].plan_handle);
692 }
693
694 /* === The BACK Direction === */
695 /* this is needed because slightly different functions are used */
696 for (int i = 1; i < 4; i++) {
697 if (init_tag) {
698#pragma omp critical(fftw_destroy_plan_backward)
699 back[i].destroy_plan();
700 }
701 back[i].dir = FFTW_BACKWARD;
702#pragma omp critical(fftw_create_plan_backward)
703 back[i].plan_handle = fftw<FloatType>::plan_many_dft(
704 1, &forw[i].new_mesh[2], forw[i].n_ffts, c_data, nullptr, 1,
705 forw[i].new_mesh[2], c_data, nullptr, 1, forw[i].new_mesh[2],
706 back[i].dir, FFTW_PATIENT);
707 back[i].pack_function = pack_block_permute1;
708 assert(back[i].plan_handle);
709 }
710 if (forw[1].row_dir == 2) {
711 back[1].pack_function = fft_pack_block;
712 } else if (forw[1].row_dir == 1) {
713 back[1].pack_function = pack_block_permute2;
714 }
715
716 init_tag = true;
717
718 return max_mesh_size;
719}
720
721template <typename FloatType>
723 boost::mpi::communicator const &comm, FloatType *data) {
724 /* ===== first direction ===== */
725
726 auto *c_data = (typename fftw<FloatType>::complex *)data;
727 auto *c_data_buf = (typename fftw<FloatType>::complex *)data_buf.data();
728
729 /* communication to current dir row format (in is data) */
730 forw_grid_comm(comm, forw[1], data, data_buf.data());
731
732 /* complexify the real data array (in is data_buf) */
733 for (int i = 0; i < forw[1].new_size; i++) {
734 data[2 * i + 0] = data_buf[i]; /* real value */
735 data[2 * i + 1] = FloatType(0); /* complex value */
736 }
737 /* perform FFT (in/out is data)*/
738 fftw<FloatType>::execute_dft(forw[1].plan_handle, c_data, c_data);
739 /* ===== second direction ===== */
740 /* communication to current dir row format (in is data) */
741 forw_grid_comm(comm, forw[2], data, data_buf.data());
742 /* perform FFT (in/out is data_buf) */
743 fftw<FloatType>::execute_dft(forw[2].plan_handle, c_data_buf, c_data_buf);
744 /* ===== third direction ===== */
745 /* communication to current dir row format (in is data_buf) */
746 forw_grid_comm(comm, forw[3], data_buf.data(), data);
747 /* perform FFT (in/out is data)*/
748 fftw<FloatType>::execute_dft(forw[3].plan_handle, c_data, c_data);
749
750 /* REMARK: Result has to be in data. */
751}
752
753template <typename FloatType>
755 boost::mpi::communicator const &comm, FloatType *data) {
756
757 auto *c_data = (typename fftw<FloatType>::complex *)data;
758 auto *c_data_buf = (typename fftw<FloatType>::complex *)data_buf.data();
759
760 /* ===== third direction ===== */
761
762 /* perform FFT (in is data) */
763 fftw<FloatType>::execute_dft(back[3].plan_handle, c_data, c_data);
764 /* communicate (in is data)*/
765 back_grid_comm(comm, forw[3], back[3], data, data_buf.data());
766
767 /* ===== second direction ===== */
768 /* perform FFT (in is data_buf) */
769 fftw<FloatType>::execute_dft(back[2].plan_handle, c_data_buf, c_data_buf);
770 /* communicate (in is data_buf) */
771 back_grid_comm(comm, forw[2], back[2], data_buf.data(), data);
772
773 /* ===== first direction ===== */
774 /* perform FFT (in is data) */
775 fftw<FloatType>::execute_dft(back[1].plan_handle, c_data, c_data);
776 /* throw away the (hopefully) empty complex component (in is data) */
777 for (int i = 0; i < forw[1].new_size; i++) {
778 data_buf[i] = data[2 * i]; /* real value */
779 }
780 /* communicate (in is data_buf) */
781 back_grid_comm(comm, forw[1], back[1], data_buf.data(), data);
782
783 /* REMARK: Result has to be in data. */
784}
785
786template <typename FloatType> void fft_plan<FloatType>::destroy_plan() {
787 if (plan_handle) {
789 plan_handle = nullptr;
790 }
791}
792
793template <class FloatType>
794FloatType *allocator<FloatType>::allocate(const std::size_t n) const {
795 if (n == 0) {
796 return nullptr;
797 }
798 if (n > std::numeric_limits<std::size_t>::max() / sizeof(FloatType)) {
799 throw std::bad_array_new_length();
800 }
801 void *const pv = fftw<FloatType>::malloc(n * sizeof(FloatType));
802 if (!pv) {
803 throw std::bad_alloc();
804 }
805 return static_cast<FloatType *>(pv);
806}
807
808template <class FloatType>
809void allocator<FloatType>::deallocate(FloatType *const p,
810 std::size_t) const noexcept {
811 fftw<FloatType>::free(static_cast<void *>(p));
812}
813
814template struct allocator<float>;
815template struct allocator<double>;
816
817template struct fft_plan<float>;
818template struct fft_plan<double>;
819
820template struct fft_forw_plan<float>;
821template struct fft_forw_plan<double>;
822
823template struct fft_back_plan<float>;
824template struct fft_back_plan<double>;
825
826template struct fft_data_struct<float>;
827template struct fft_data_struct<double>;
828
829} // namespace fft
Vector implementation and trait types for boost qvm interoperability.
DEVICE_QUALIFIER constexpr pointer data() noexcept
Definition Array.hpp:132
static double * block(double *p, std::size_t index, std::size_t size)
Definition elc.cpp:175
static void fft_sendrecv(T const *const sendbuf, int scount, int dest, T *const recvbuf, int rcount, int source, boost::mpi::communicator const &comm, int tag)
Definition fft.cpp:67
#define REQ_FFT_BACK
Tag for communication in back_grid_comm()
Definition fft.cpp:63
#define REQ_FFT_FORW
Tag for communication in forw_grid_comm()
Definition fft.cpp:61
Routines, row decomposition, data structures and communication for the 3D-FFT.
T product(Vector< T, N > const &v)
Definition Vector.hpp:372
void permute_ifield(int *field, int size, int permute)
permute an integer array field of size size about permute positions.
int get_linear_index(int a, int b, int c, const Vector3i &adim)
Definition index.hpp:35
void pack_block_permute2(FloatType const *const in, FloatType *const out, const int *start, const int *size, const int *dim, int element)
Pack a block with dimensions size[0] * size[1] * size[2] starting at start of an input 3D-grid with d...
Definition fft.cpp:359
void pack_block_permute1(FloatType const *const in, FloatType *const out, const int *start, const int *size, const int *dim, int element)
Pack a block with dimensions size[0] * size[1] * size[2] starting at start of an input 3D-grid with d...
Definition fft.cpp:310
int calc_send_block(const int *pos1, const int *grid1, const int *pos2, const int *grid2, const int *mesh, const double *mesh_off, int *block)
Calculate a send (or recv.) block for grid communication during a decomposition change.
Definition fft.cpp:268
int calc_local_mesh(const int *n_pos, const int *n_grid, const int *mesh, const double *mesh_off, int *loc_mesh, int *start)
Calculate the local fft mesh.
Definition fft.cpp:224
Definition fft.cpp:76
std::optional< std::vector< int > > find_comm_groups(Utils::Vector3i const &grid1, Utils::Vector3i const &grid2, std::span< int const > node_list1, std::span< int > node_list2, std::span< int > pos, std::span< int > my_pos, int rank)
This ugly function does the bookkeeping: which nodes have to communicate to each other,...
Definition fft.cpp:119
void calc_2d_grid(int n, int grid[3])
Calculate most square 2D grid.
Definition fft.cpp:503
int map_3don2d_grid(int const g3d[3], int g2d[3])
Calculate 'best' mapping between a 2D and 3D grid.
Definition fft.cpp:461
void fft_pack_block(FloatType const *const in, FloatType *const out, int const *start, int const *size, int const *dim, int element)
Pack a 3D-block of size size starting at start of an input 3D-grid in with dimension dim into an outp...
Definition packing.hpp:44
void fft_unpack_block(FloatType const *const in, FloatType *const out, int const *start, int const *size, int const *dim, int element)
Unpack a 3D-block in of size size into an output 3D-grid out of size dim starting at position start.
Definition packing.hpp:83
Aligned allocator for FFT data.
Definition vector.hpp:33
FloatType * allocate(std::size_t n) const
Definition fft.cpp:794
void deallocate(FloatType *p, std::size_t) const noexcept
Definition fft.cpp:809
Plan for a backward 1D FFT of a flattened 3D array.
Definition fft.hpp:116
Information about the three one-dimensional FFTs and how the nodes have to communicate in between.
Definition fft.hpp:125
void forward_fft(boost::mpi::communicator const &comm, FloatType *data)
Perform an in-place forward 3D FFT.
Definition fft.cpp:722
void backward_fft(boost::mpi::communicator const &comm, FloatType *data)
Perform an in-place backward 3D FFT.
Definition fft.cpp:754
int initialize_fft(boost::mpi::communicator const &comm, Utils::Vector3i const &ca_mesh_dim, int const *ca_mesh_margin, Utils::Vector3i const &global_mesh_dim, Utils::Vector3d const &global_mesh_off, int &ks_pnum, Utils::Vector3i const &grid)
Initialize everything connected to the 3D-FFT.
Definition fft.cpp:518
Plan for a forward 1D FFT of a flattened 3D array.
Definition fft.hpp:82
void destroy_plan()
Definition fft.cpp:786
fftwf_complex complex
Definition fft.cpp:87
static auto constexpr malloc
Definition fft.cpp:83
static auto constexpr free
Definition fft.cpp:84
fftw_complex complex
Definition fft.cpp:79
static auto constexpr execute_dft
Definition fft.cpp:82
static auto constexpr destroy_plan
Definition fft.cpp:81
static auto constexpr plan_many_dft
Definition fft.cpp:80