ESPResSo
Extensible Simulation Package for Research on Soft Matter Systems
Loading...
Searching...
No Matches
fft.cpp
Go to the documentation of this file.
1/*
2 * Copyright (C) 2010-2024 The ESPResSo project
3 * Copyright (C) 2002,2003,2004,2005,2006,2007,2008,2009,2010
4 * Max-Planck-Institute for Polymer Research, Theory Group
5 *
6 * This file is part of ESPResSo.
7 *
8 * ESPResSo is free software: you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation, either version 3 of the License, or
11 * (at your option) any later version.
12 *
13 * ESPResSo is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with this program. If not, see <http://www.gnu.org/licenses/>.
20 */
21/** \file
22 *
23 * Routines, row decomposition, data structures and communication for the
24 * 3D-FFT.
25 *
26 */
27
28#include "fft.hpp"
29#include "vector.hpp"
30
31#include "p3m/packing.hpp"
32
33#include <utils/Vector.hpp>
34#include <utils/index.hpp>
37
38#include <boost/mpi/communicator.hpp>
39#include <boost/serialization/vector.hpp>
40
41#include <fftw3.h>
42
43#include <algorithm>
44#include <cmath>
45#include <cstddef>
46#include <cstdio>
47#include <cstring>
48#include <optional>
49#include <span>
50#include <stdexcept>
51#include <utility>
52#include <vector>
53
55
/** @name MPI tags for FFT communication */
/**@{*/
/** Tag for communication in forw_grid_comm() */
static constexpr int REQ_FFT_FORW = 301;
/** Tag for communication in back_grid_comm() */
static constexpr int REQ_FFT_BACK = 302;
/**@}*/
63
64template <typename T>
65static void fft_sendrecv(T const *const sendbuf, int scount, int dest,
66 T *const recvbuf, int rcount, int source,
67 boost::mpi::communicator const &comm, int tag) {
68 auto const type = boost::mpi::get_mpi_datatype<T>(*sendbuf);
69 MPI_Sendrecv(reinterpret_cast<void const *>(sendbuf), scount, type, dest, tag,
70 reinterpret_cast<void *>(recvbuf), rcount, type, source, tag,
71 comm, MPI_STATUS_IGNORE);
72}
73
74namespace fft {
75
76template <typename FloatType = double> struct fftw {
77 using complex = fftw_complex;
78 static auto constexpr plan_many_dft = fftw_plan_many_dft;
79 static auto constexpr destroy_plan = fftw_destroy_plan;
80 static auto constexpr execute_dft = fftw_execute_dft;
81 static auto constexpr malloc = fftw_malloc;
82 static auto constexpr free = fftw_free;
83};
84template <> struct fftw<float> {
85 using complex = fftwf_complex;
86 static auto constexpr plan_many_dft = fftwf_plan_many_dft;
87 static auto constexpr destroy_plan = fftwf_destroy_plan;
88 static auto constexpr execute_dft = fftwf_execute_dft;
89 static auto constexpr malloc = fftwf_malloc;
90 static auto constexpr free = fftwf_free;
91};
92
93/** This ugly function does the bookkeeping: which nodes have to
94 * communicate to each other, when you change the node grid.
95 * Changing the regular decomposition requires communication. This
96 * function finds (hopefully) the best way to do this. As input it
97 * needs the two grids (@p grid1, @p grid2) and a linear list (@p node_list1)
98 * with the node identities for @p grid1. The linear list (@p node_list2)
99 * for the second grid is calculated. For the communication group of
100 * the calling node it calculates a list (@c group) with the node
101 * identities and the positions (@p my_pos, @p pos) of that nodes in @p grid1
102 * and @p grid2. The return value is the size of the communication
103 * group. It gives -1 if the two grids do not fit to each other
104 * (@p grid1 and @p grid2 have to be component-wise multiples of each
105 * other, see e.g. \ref calc_2d_grid for how to do this).
106 *
107 * \param[in] grid1 The node grid you start with.
108 * \param[in] grid2 The node grid you want to have.
109 * \param[in] node_list1 Linear node index list for grid1.
110 * \param[out] node_list2 Linear node index list for grid2.
111 * \param[out] pos Positions of the nodes in grid2
112 * \param[out] my_pos Position of comm.rank() in grid2.
113 * \param[in] rank MPI rank.
114 * \return Size of the communication group.
115 */
116std::optional<std::vector<int>>
118 std::span<int const> node_list1, std::span<int> node_list2,
119 std::span<int> pos, std::span<int> my_pos, int rank) {
120 int i;
121 /* communication group cell size on grid1 and grid2 */
122 int s1[3], s2[3];
123 /* The communication group cells build the same super grid on grid1 and grid2
124 */
125 int ds[3];
126 /* communication group size */
127 int g_size = 1;
128 /* comm. group cell index */
129 int gi[3];
130 /* position of a node in a grid */
131 Utils::Vector3i p1, p2;
132 /* node identity */
133 int n;
134 /* comm.rank() position in the communication group. */
135 int c_pos = -1;
136 /* flag for group identification */
137 int my_group = 0;
138
139 /* calculate dimension of comm. group cells for both grids */
140 if (Utils::product(grid1) != Utils::product(grid2))
141 return std::nullopt; /* unlike number of nodes */
142 for (i = 0; i < 3; i++) {
143 s1[i] = grid1[i] / grid2[i];
144 if (s1[i] == 0)
145 s1[i] = 1;
146 else if (grid1[i] != grid2[i] * s1[i])
147 return std::nullopt; /* grids do not match!!! */
148
149 s2[i] = grid2[i] / grid1[i];
150 if (s2[i] == 0)
151 s2[i] = 1;
152 else if (grid2[i] != grid1[i] * s2[i])
153 return std::nullopt; /* grids do not match!!! */
154
155 ds[i] = grid2[i] / s2[i];
156 g_size *= s2[i];
157 }
158
159 std::vector<int> group(g_size);
160
161 /* calc node_list2 */
162 /* loop through all comm. group cells */
163 for (gi[2] = 0; gi[2] < ds[2]; gi[2]++)
164 for (gi[1] = 0; gi[1] < ds[1]; gi[1]++)
165 for (gi[0] = 0; gi[0] < ds[0]; gi[0]++) {
166 /* loop through all nodes in that comm. group cell */
167 for (i = 0; i < g_size; i++) {
168 p1[0] = (gi[0] * s1[0]) + (i % s1[0]);
169 p1[1] = (gi[1] * s1[1]) + ((i / s1[0]) % s1[1]);
170 p1[2] = (gi[2] * s1[2]) + (i / (s1[0] * s1[1]));
171
172 p2[0] = (gi[0] * s2[0]) + (i % s2[0]);
173 p2[1] = (gi[1] * s2[1]) + ((i / s2[0]) % s2[1]);
174 p2[2] = (gi[2] * s2[2]) + (i / (s2[0] * s2[1]));
175
176 n = node_list1[Utils::get_linear_index(p1, grid1)];
177 node_list2[Utils::get_linear_index(p2, grid2)] = n;
178
179 pos[3 * n + 0] = p2[0];
180 pos[3 * n + 1] = p2[1];
181 pos[3 * n + 2] = p2[2];
182 if (my_group == 1)
183 group[i] = n;
184 if (n == rank && my_group == 0) {
185 my_group = 1;
186 c_pos = i;
187 my_pos[0] = p2[0];
188 my_pos[1] = p2[1];
189 my_pos[2] = p2[2];
190 i = -1; /* restart the loop */
191 }
192 }
193 my_group = 0;
194 }
195
196 /* permute comm. group according to the nodes position in the group */
197 /* This is necessary to have matching node pairs during communication! */
198 while (c_pos > 0) {
199 n = group[g_size - 1];
200 for (i = g_size - 1; i > 0; i--)
201 group[i] = group[i - 1];
202 group[0] = n;
203 c_pos--;
204 }
205 return {group};
206}
207
208namespace {
/** Determine the part of a global mesh owned by one node of a node grid.
 *
 *  In every direction the global mesh of @p mesh points is split evenly
 *  over @p n_grid nodes. The node at @p n_pos owns the mesh points whose
 *  positions, shifted by @p mesh_off (in mesh units), fall into its slab.
 *  A tolerance of 1e-15 compensates for round-off at the slab boundaries.
 *
 *  \param[in]  n_pos    Position of the node in @p n_grid.
 *  \param[in]  n_grid   node grid.
 *  \param[in]  mesh     global mesh dimensions.
 *  \param[in]  mesh_off global mesh offset.
 *  \param[out] loc_mesh local mesh dimension.
 *  \param[out] start    first point of local mesh in global mesh.
 *  \return Number of mesh points in local mesh.
 */
int calc_local_mesh(const int *n_pos, const int *n_grid, const int *mesh,
                    const double *mesh_off, int *loc_mesh, int *start) {
  int n_points = 1;

  for (int d = 0; d < 3; ++d) {
    auto const cells_per_node = mesh[d] / static_cast<double>(n_grid[d]);
    auto const slab_lo = cells_per_node * n_pos[d];
    auto const slab_hi = cells_per_node * (n_pos[d] + 1);

    auto first = static_cast<int>(std::ceil(slab_lo - mesh_off[d]));
    auto last = static_cast<int>(std::floor(slab_hi - mesh_off[d]));

    /* correct round off errors at the slab boundaries */
    if (slab_hi - mesh_off[d] - last < 1.0e-15) {
      --last;
    }
    if (1.0 + slab_lo - mesh_off[d] - first < 1.0e-15) {
      --first;
    }

    start[d] = first;
    loc_mesh[d] = last - first + 1;
    n_points *= loc_mesh[d];
  }

  return n_points;
}
240
241/** Calculate a send (or recv.) block for grid communication during a
242 * decomposition change. Calculate the send block specification
243 * (block = lower left corner and upper right corner) which a node at
244 * position (@p pos1) in the actual node grid (@p grid1) has to send to
245 * another node at position (@p pos2) in the desired node grid (@p grid2).
246 * The global mesh, subject to communication, is specified via its size
247 * (@p mesh) and its mesh offset (@p mesh_off (in mesh units)).
248 *
249 * For the calculation of a receive block you have to change the arguments in
250 * the following way:
251 * - @p pos1: position of receiving node in the desired node grid.
252 * - @p grid1: desired node grid.
253 * - @p pos2: position of the node you intend to receive the data from in the
254 * actual node grid.
255 * - @p grid2: actual node grid.
256 *
257 * \param[in] pos1 Position of send node in @p grid1.
258 * \param[in] grid1 node grid 1.
259 * \param[in] pos2 Position of recv node in @p grid2.
260 * \param[in] grid2 node grid 2.
261 * \param[in] mesh global mesh dimensions.
262 * \param[in] mesh_off global mesh offset.
263 * \param[out] block send block specification.
264 * \return Size of the send block.
265 */
266int calc_send_block(const int *pos1, const int *grid1, const int *pos2,
267 const int *grid2, const int *mesh, const double *mesh_off,
268 int *block) {
269 int size = 1;
270 int mesh1[3], first1[3], last1[3];
271 int mesh2[3], first2[3], last2[3];
272
273 calc_local_mesh(pos1, grid1, mesh, mesh_off, mesh1, first1);
274 calc_local_mesh(pos2, grid2, mesh, mesh_off, mesh2, first2);
275
276 for (int i = 0; i < 3; i++) {
277 last1[i] = first1[i] + mesh1[i] - 1;
278 last2[i] = first2[i] + mesh2[i] - 1;
279 block[i] = std::max(first1[i], first2[i]) - first1[i];
280 block[i + 3] = (std::min(last1[i], last2[i]) - first1[i]) - block[i] + 1;
281 size *= block[i + 3];
282 }
283 return size;
284}
285
/** Pack a block with dimensions <tt>size[0] * size[1] * size[2]</tt> starting
 *  at @p start of an input 3D-grid with dimension @p dim into an output
 *  3D-grid with dimensions <tt>size[1] * size[2] * size[0]</tt> (slowest to
 *  fastest index) with a simultaneous one-fold permutation of the indices.
 *  The permutation is defined as:
 *  slow_in -> fast_out, mid_in -> slow_out, fast_in -> mid_out.
 *
 *  An element <tt>(i0_in, i1_in, i2_in)</tt> is then
 *  <tt>(i0_out = i1_in-start[1], i1_out = i2_in-start[2],
 *  i2_out = i0_in-start[0])</tt> and for the linear indices (in units of
 *  grid elements) we have:
 *  - <tt>li_in = i2_in + dim[2] * (i1_in + (dim[1]*i0_in))</tt>
 *  - <tt>li_out = i2_out + size[0] * (i1_out + (size[2]*i0_out))</tt>
 *
 *  For index definition see \ref fft_pack_block.
 *
 *  \param[in]  in      input 3D-grid.
 *  \param[out] out     output 3D-grid (block).
 *  \param[in]  start   start index of the block in the in-grid.
 *  \param[in]  size    size of the block (=dimension of the out-grid).
 *  \param[in]  dim     size of the in-grid.
 *  \param[in]  element size of a grid element (e.g. 1 for Real, 2 for Complex).
 */
template <typename FloatType>
void pack_block_permute1(FloatType const *const in, FloatType *const out,
                         const int *start, const int *size, const int *dim,
                         int element) {
  /* offsets to skip the parts of the input grid outside the block */
  auto const m_in_offset = element * (dim[2] - size[2]);
  auto const s_in_offset = element * (dim[2] * (dim[1] - size[1]));
  /* distance between consecutive mid-changing indices of the output grid */
  auto const f_out_stride = element * size[0];
  /* linear index of in grid */
  int li_in = element * (start[2] + dim[2] * (start[1] + dim[1] * start[0]));

  for (int s = 0; s < size[0]; s++) { /* fast changing out */
    /* linear index of out grid */
    int li_out = element * s;
    for (int m = 0; m < size[1]; m++) {   /* slow changing out */
      for (int f = 0; f < size[2]; f++) { /* mid changing out */
        std::copy_n(in + li_in, element, out + li_out);
        li_in += element;
        li_out += f_out_stride;
      }
      li_in += m_in_offset;
    }
    li_in += s_in_offset;
  }
}
334
/** Pack a block with dimensions <tt>size[0] * size[1] * size[2]</tt> starting
 *  at @p start of an input 3D-grid with dimension @p dim into an output
 *  3D-grid with dimensions <tt>size[2] * size[0] * size[1]</tt> (slowest to
 *  fastest index) with a simultaneous two-fold permutation of the indices.
 *  The permutation is defined as:
 *  slow_in -> mid_out, mid_in -> fast_out, fast_in -> slow_out.
 *
 *  An element <tt>(i0_in, i1_in, i2_in)</tt> is then
 *  <tt>(i0_out = i2_in-start[2], i1_out = i0_in-start[0],
 *  i2_out = i1_in-start[1])</tt> and for the linear indices (in units of
 *  grid elements) we have:
 *  - <tt>li_in = i2_in + dim[2] * (i1_in + (dim[1]*i0_in))</tt>
 *  - <tt>li_out = i2_out + size[1] * (i1_out + (size[0]*i0_out))</tt>
 *
 *  For index definition see \ref fft_pack_block.
 *
 *  \param[in]  in      input 3D-grid.
 *  \param[out] out     output 3D-grid (block).
 *  \param[in]  start   start index of the block in the in-grid.
 *  \param[in]  size    size of the block (=dimension of the out-grid).
 *  \param[in]  dim     size of the in-grid.
 *  \param[in]  element size of a grid element (e.g. 1 for Real, 2 for Complex).
 */
template <typename FloatType>
void pack_block_permute2(FloatType const *const in, FloatType *const out,
                         const int *start, const int *size, const int *dim,
                         int element) {
  /* offsets to skip the parts of the input grid outside the block */
  auto const m_in_offset = element * (dim[2] - size[2]);
  auto const s_in_offset = element * (dim[2] * (dim[1] - size[1]));
  /* distance between consecutive slow-changing indices of the output grid */
  auto const f_out_stride = element * size[0] * size[1];
  /* linear index of in grid */
  int li_in = element * (start[2] + dim[2] * (start[1] + dim[1] * start[0]));

  for (int s = 0; s < size[0]; s++) {   /* mid changing out */
    for (int m = 0; m < size[1]; m++) { /* fast changing out */
      /* linear index of out grid */
      int li_out = element * (s * size[1] + m);
      for (int f = 0; f < size[2]; f++) { /* slow changing out */
        std::copy_n(in + li_in, element, out + li_out);
        li_in += element;
        li_out += f_out_stride;
      }
      li_in += m_in_offset;
    }
    li_in += s_in_offset;
  }
}
384
385} // namespace
386
387/** Communicate the grid data according to the given forward FFT plan.
388 * \param comm MPI communicator.
389 * \param plan FFT communication plan.
390 * \param in input mesh.
391 * \param out output mesh.
392 */
393template <typename FloatType>
394void fft_data_struct<FloatType>::forw_grid_comm(
395 boost::mpi::communicator const &comm, fft_forw_plan<FloatType> const &plan,
396 FloatType const *in, FloatType *out) {
397 for (std::size_t i = 0ul; i < plan.group.size(); i++) {
398 plan.pack_function(in, send_buf.data(), &(plan.send_block[6ul * i]),
399 &(plan.send_block[6ul * i + 3ul]), plan.old_mesh.data(),
400 plan.element);
401
402 if (plan.group[i] != comm.rank()) {
403 fft_sendrecv(send_buf.data(), plan.send_size[i], plan.group[i],
404 recv_buf.data(), plan.recv_size[i], plan.group[i], comm,
406 } else { /* Self communication... */
407 std::swap(send_buf, recv_buf);
408 }
409 fft_unpack_block(recv_buf.data(), out, &(plan.recv_block[6ul * i]),
410 &(plan.recv_block[6ul * i + 3ul]), plan.new_mesh.data(),
411 plan.element);
412 }
413}
414
415/** Communicate the grid data according to the given backward FFT plan.
416 * \param comm MPI communicator.
417 * \param plan_f Forward FFT plan.
418 * \param plan_b Backward FFT plan.
419 * \param in input mesh.
420 * \param out output mesh.
421 */
422template <typename FloatType>
423void fft_data_struct<FloatType>::back_grid_comm(
424 boost::mpi::communicator const &comm,
425 fft_forw_plan<FloatType> const &plan_f,
426 fft_back_plan<FloatType> const &plan_b, FloatType const *in,
427 FloatType *out) {
428 /* Back means: Use the send/receive stuff from the forward plan but
429 replace the receive blocks by the send blocks and vice
430 versa. Attention then also new_mesh and old_mesh are exchanged */
431
432 for (std::size_t i = 0ul; i < plan_f.group.size(); i++) {
433 plan_b.pack_function(in, send_buf.data(), &(plan_f.recv_block[6ul * i]),
434 &(plan_f.recv_block[6ul * i + 3ul]),
435 plan_f.new_mesh.data(), plan_f.element);
436
437 if (plan_f.group[i] != comm.rank()) { /* send first, receive second */
438 fft_sendrecv(send_buf.data(), plan_f.recv_size[i], plan_f.group[i],
439 recv_buf.data(), plan_f.send_size[i], plan_f.group[i], comm,
441 } else { /* Self communication... */
442 std::swap(send_buf, recv_buf);
443 }
444 fft_unpack_block(recv_buf.data(), out, &(plan_f.send_block[6ul * i]),
445 &(plan_f.send_block[6ul * i + 3ul]),
446 plan_f.old_mesh.data(), plan_f.element);
447 }
448}
449
/** Calculate 'best' mapping between a 2D and 3D grid.
 *  Required for the communication from 3D regular domain
 *  decomposition to 2D regular row decomposition.
 *
 *  The dimensions of the 2D grid (<tt>g2d[0] * g2d[1]</tt>) are resorted,
 *  if necessary, so that they are multiples of the 3D grid dimensions; the
 *  direction that receives extent 1 becomes the row direction. The six
 *  candidate placements are tried in a fixed order of preference and the
 *  first fitting one is used. If @p g3d is already flat in the last
 *  direction (<tt>g3d[2] == 1</tt>), @p g2d is left unchanged and 2 is
 *  returned.
 *
 *  \param[in]     g3d 3D grid.
 *  \param[in,out] g2d 2D grid.
 *  \return index of the row direction [0,1,2], or -1 if no placement fits.
 */
int map_3don2d_grid(int const g3d[3], int g2d[3]) {
  /* trivial case */
  if (g3d[2] == 1) {
    return 2;
  }
  int const a = g2d[0];
  int const b = g2d[1];
  auto const fits = [](int n2d, int n3d) { return n2d % n3d == 0; };

  /* Try every placement instead of only those compatible with the first
     divisibility test that happens to hold: a later placement may still
     fit when the earlier ones do not. */
  if (fits(a, g3d[0]) and fits(b, g3d[1])) { /* (a, b, *) */
    return 2;
  }
  if (fits(a, g3d[0]) and fits(b, g3d[2])) { /* (a, 1, b) */
    g2d[1] = 1;
    g2d[2] = b;
    return 1;
  }
  if (fits(a, g3d[1]) and fits(b, g3d[0])) { /* (b, a, *) */
    g2d[0] = b;
    g2d[1] = a;
    return 2;
  }
  if (fits(a, g3d[1]) and fits(b, g3d[2])) { /* (1, a, b) */
    g2d[0] = 1;
    g2d[1] = a;
    g2d[2] = b;
    return 0;
  }
  if (fits(a, g3d[2]) and fits(b, g3d[0])) { /* (b, 1, a) */
    g2d[0] = b;
    g2d[1] = 1;
    g2d[2] = a;
    return 1;
  }
  if (fits(a, g3d[2]) and fits(b, g3d[1])) { /* (1, b, a) */
    g2d[0] = 1;
    g2d[1] = b;
    g2d[2] = a;
    return 0;
  }
  return -1;
}
499
/** Calculate most square 2D grid.
 *
 *  Stores <tt>{n / d, d, 1}</tt> in @p grid, where @c d is the largest
 *  divisor of @p n that does not exceed <tt>sqrt(n)</tt>, so that
 *  <tt>grid[0] >= grid[1]</tt>. For <tt>n == 0</tt> the result is
 *  <tt>{0, 1, 1}</tt>.
 */
inline void calc_2d_grid(int n, int grid[3]) {
  auto divisor = static_cast<int>(std::sqrt(n));
  while (divisor >= 1 and n % divisor != 0) {
    --divisor;
  }
  if (divisor >= 1) {
    grid[0] = n / divisor;
    grid[1] = divisor;
  } else {
    grid[0] = n;
    grid[1] = 1;
  }
  grid[2] = 1;
}
514
515template <typename FloatType>
517 boost::mpi::communicator const &comm, Utils::Vector3i const &ca_mesh_dim,
518 int const *ca_mesh_margin, Utils::Vector3i const &global_mesh_dim,
519 Utils::Vector3d const &global_mesh_off, int &ks_pnum,
520 Utils::Vector3i const &grid) {
521
522 int n_grid[4][3]; /* The four node grids. */
523 int my_pos[4][3]; /* The position of comm.rank() in the node grids. */
524 std::vector<int> n_id[4]; /* linear node identity lists for the node grids. */
525 std::vector<int> n_pos[4]; /* positions of nodes in the node grids. */
526
527 auto const rank = comm.rank();
528 auto const node_pos = Utils::Mpi::cart_coords<3>(comm, rank);
529
530 max_comm_size = 0;
531 max_mesh_size = 0;
532 for (int i = 0; i < 4; i++) {
533 n_id[i].resize(1 * comm.size());
534 n_pos[i].resize(3 * comm.size());
535 }
536
537 /* === node grids === */
538 /* real space node grid (n_grid[0]) */
539 for (int i = 0; i < 3; i++) {
540 n_grid[0][i] = grid[i];
541 my_pos[0][i] = node_pos[i];
542 }
543 for (int i = 0; i < comm.size(); i++) {
544 auto const n_pos_i = Utils::Mpi::cart_coords<3>(comm, i);
545 for (int j = 0; j < 3; ++j) {
546 n_pos[0][3 * i + j] = n_pos_i[j];
547 }
548 auto const lin_ind = Utils::get_linear_index(
549 n_pos_i, {n_grid[0][0], n_grid[0][1], n_grid[0][2]});
550 n_id[0][lin_ind] = i;
551 }
552
553 /* FFT node grids (n_grid[1 - 3]) */
554 calc_2d_grid(comm.size(), n_grid[1]);
555 /* resort n_grid[1] dimensions if necessary */
556 forw[1].row_dir = map_3don2d_grid(n_grid[0], n_grid[1]);
557 forw[0].n_permute = 0;
558 for (int i = 1; i < 4; i++)
559 forw[i].n_permute = (forw[1].row_dir + i) % 3;
560 for (int i = 0; i < 3; i++) {
561 n_grid[2][i] = n_grid[1][(i + 1) % 3];
562 n_grid[3][i] = n_grid[1][(i + 2) % 3];
563 }
564 forw[2].row_dir = (forw[1].row_dir - 1) % 3;
565 forw[3].row_dir = (forw[1].row_dir - 2) % 3;
566
567 /* === communication groups === */
568 /* copy local mesh off real space charge assignment grid */
569 for (int i = 0; i < 3; i++)
570 forw[0].new_mesh[i] = ca_mesh_dim[i];
571
572 for (int i = 1; i < 4; i++) {
573 auto group = find_comm_groups(
574 {n_grid[i - 1][0], n_grid[i - 1][1], n_grid[i - 1][2]},
575 {n_grid[i][0], n_grid[i][1], n_grid[i][2]}, n_id[i - 1],
576 std::span(n_id[i]), std::span(n_pos[i]), my_pos[i], rank);
577 if (not group) {
578 /* try permutation */
579 std::swap(n_grid[i][(forw[i].row_dir + 1) % 3],
580 n_grid[i][(forw[i].row_dir + 2) % 3]);
581
582 group = find_comm_groups(
583 {n_grid[i - 1][0], n_grid[i - 1][1], n_grid[i - 1][2]},
584 {n_grid[i][0], n_grid[i][1], n_grid[i][2]}, std::span(n_id[i - 1]),
585 std::span(n_id[i]), std::span(n_pos[i]), my_pos[i], rank);
586
587 if (not group) {
588 throw std::runtime_error("INTERNAL ERROR: fft_find_comm_groups error");
589 }
590 }
591
592 forw[i].group = group.value();
593
594 forw[i].send_block.resize(6 * forw[i].group.size());
595 forw[i].send_size.resize(forw[i].group.size());
596 forw[i].recv_block.resize(6 * forw[i].group.size());
597 forw[i].recv_size.resize(forw[i].group.size());
598
599 forw[i].new_size = calc_local_mesh(
600 my_pos[i], n_grid[i], global_mesh_dim.data(), global_mesh_off.data(),
601 forw[i].new_mesh.data(), forw[i].start.data());
602 permute_ifield(forw[i].new_mesh.data(), 3, -(forw[i].n_permute));
603 permute_ifield(forw[i].start.data(), 3, -(forw[i].n_permute));
604 forw[i].n_ffts = forw[i].new_mesh[0] * forw[i].new_mesh[1];
605
606 /* === send/recv block specifications === */
607 for (std::size_t j = 0ul; j < forw[i].group.size(); j++) {
608 /* send block: comm.rank() to comm-group-node i (identity: node) */
609 int node = forw[i].group[j];
610 forw[i].send_size[j] = calc_send_block(
611 my_pos[i - 1], n_grid[i - 1], &(n_pos[i][3 * node]), n_grid[i],
612 global_mesh_dim.data(), global_mesh_off.data(),
613 &(forw[i].send_block[6ul * j]));
614 permute_ifield(&(forw[i].send_block[6ul * j]), 3,
615 -(forw[i - 1].n_permute));
616 permute_ifield(&(forw[i].send_block[6ul * j + 3ul]), 3,
617 -(forw[i - 1].n_permute));
618 if (forw[i].send_size[j] > max_comm_size)
619 max_comm_size = forw[i].send_size[j];
620 /* First plan send blocks have to be adjusted, since the CA grid
621 may have an additional margin outside the actual domain of the
622 node */
623 if (i == 1) {
624 for (std::size_t k = 0ul; k < 3ul; k++)
625 forw[1].send_block[6ul * j + k] += ca_mesh_margin[2ul * k];
626 }
627 /* recv block: comm.rank() from comm-group-node i (identity: node) */
628 forw[i].recv_size[j] = calc_send_block(
629 my_pos[i], n_grid[i], &(n_pos[i - 1][3 * node]), n_grid[i - 1],
630 global_mesh_dim.data(), global_mesh_off.data(),
631 &(forw[i].recv_block[6ul * j]));
632 permute_ifield(&(forw[i].recv_block[6ul * j]), 3, -(forw[i].n_permute));
633 permute_ifield(&(forw[i].recv_block[6ul * j + 3ul]), 3,
634 -(forw[i].n_permute));
635 if (forw[i].recv_size[j] > max_comm_size)
636 max_comm_size = forw[i].recv_size[j];
637 }
638
639 for (std::size_t j = 0ul; j < 3ul; j++)
640 forw[i].old_mesh[j] = forw[i - 1].new_mesh[j];
641 if (i == 1) {
642 forw[i].element = 1;
643 } else {
644 forw[i].element = 2;
645 for (std::size_t j = 0ul; j < forw[i].group.size(); j++) {
646 forw[i].send_size[j] *= 2;
647 forw[i].recv_size[j] *= 2;
648 }
649 }
650 }
651
652 /* Factor 2 for complex fields */
653 max_comm_size *= 2;
654 max_mesh_size = Utils::product(ca_mesh_dim);
655 for (int i = 1; i < 4; i++)
656 if (2 * forw[i].new_size > max_mesh_size)
657 max_mesh_size = 2 * forw[i].new_size;
658
659 /* === pack function === */
660 for (int i = 1; i < 4; i++) {
661 forw[i].pack_function = pack_block_permute2;
662 }
663 ks_pnum = 6;
664 if (forw[1].row_dir == 2) {
665 forw[1].pack_function = fft_pack_block;
666 ks_pnum = 4;
667 } else if (forw[1].row_dir == 1) {
668 forw[1].pack_function = pack_block_permute1;
669 ks_pnum = 5;
670 }
671
672 send_buf.resize(max_comm_size);
673 recv_buf.resize(max_comm_size);
674 data_buf.resize(max_mesh_size);
675 auto *c_data = (typename fftw<FloatType>::complex *)(data_buf.data());
676
677 /* === FFT Routines (Using FFTW / RFFTW package)=== */
678 for (int i = 1; i < 4; i++) {
679 if (init_tag) {
680 forw[i].destroy_plan();
681 }
682 forw[i].dir = FFTW_FORWARD;
683 forw[i].plan_handle = fftw<FloatType>::plan_many_dft(
684 1, &forw[i].new_mesh[2], forw[i].n_ffts, c_data, nullptr, 1,
685 forw[i].new_mesh[2], c_data, nullptr, 1, forw[i].new_mesh[2],
686 forw[i].dir, FFTW_PATIENT);
687 assert(forw[i].plan_handle);
688 }
689
690 /* === The BACK Direction === */
691 /* this is needed because slightly different functions are used */
692 for (int i = 1; i < 4; i++) {
693 if (init_tag) {
694 back[i].destroy_plan();
695 }
696 back[i].dir = FFTW_BACKWARD;
697 back[i].plan_handle = fftw<FloatType>::plan_many_dft(
698 1, &forw[i].new_mesh[2], forw[i].n_ffts, c_data, nullptr, 1,
699 forw[i].new_mesh[2], c_data, nullptr, 1, forw[i].new_mesh[2],
700 back[i].dir, FFTW_PATIENT);
701 back[i].pack_function = pack_block_permute1;
702 assert(back[i].plan_handle);
703 }
704 if (forw[1].row_dir == 2) {
705 back[1].pack_function = fft_pack_block;
706 } else if (forw[1].row_dir == 1) {
707 back[1].pack_function = pack_block_permute2;
708 }
709
710 init_tag = true;
711
712 return max_mesh_size;
713}
714
715template <typename FloatType>
717 boost::mpi::communicator const &comm, FloatType *data) {
718 /* ===== first direction ===== */
719
720 auto *c_data = (typename fftw<FloatType>::complex *)data;
721 auto *c_data_buf = (typename fftw<FloatType>::complex *)data_buf.data();
722
723 /* communication to current dir row format (in is data) */
724 forw_grid_comm(comm, forw[1], data, data_buf.data());
725
726 /* complexify the real data array (in is data_buf) */
727 for (int i = 0; i < forw[1].new_size; i++) {
728 data[2 * i + 0] = data_buf[i]; /* real value */
729 data[2 * i + 1] = FloatType(0); /* complex value */
730 }
731 /* perform FFT (in/out is data)*/
732 fftw<FloatType>::execute_dft(forw[1].plan_handle, c_data, c_data);
733 /* ===== second direction ===== */
734 /* communication to current dir row format (in is data) */
735 forw_grid_comm(comm, forw[2], data, data_buf.data());
736 /* perform FFT (in/out is data_buf) */
737 fftw<FloatType>::execute_dft(forw[2].plan_handle, c_data_buf, c_data_buf);
738 /* ===== third direction ===== */
739 /* communication to current dir row format (in is data_buf) */
740 forw_grid_comm(comm, forw[3], data_buf.data(), data);
741 /* perform FFT (in/out is data)*/
742 fftw<FloatType>::execute_dft(forw[3].plan_handle, c_data, c_data);
743
744 /* REMARK: Result has to be in data. */
745}
746
747template <typename FloatType>
749 boost::mpi::communicator const &comm, FloatType *data, bool check_complex) {
750
751 auto *c_data = (typename fftw<FloatType>::complex *)data;
752 auto *c_data_buf = (typename fftw<FloatType>::complex *)data_buf.data();
753
754 /* ===== third direction ===== */
755
756 /* perform FFT (in is data) */
757 fftw<FloatType>::execute_dft(back[3].plan_handle, c_data, c_data);
758 /* communicate (in is data)*/
759 back_grid_comm(comm, forw[3], back[3], data, data_buf.data());
760
761 /* ===== second direction ===== */
762 /* perform FFT (in is data_buf) */
763 fftw<FloatType>::execute_dft(back[2].plan_handle, c_data_buf, c_data_buf);
764 /* communicate (in is data_buf) */
765 back_grid_comm(comm, forw[2], back[2], data_buf.data(), data);
766
767 /* ===== first direction ===== */
768 /* perform FFT (in is data) */
769 fftw<FloatType>::execute_dft(back[1].plan_handle, c_data, c_data);
770 /* throw away the (hopefully) empty complex component (in is data) */
771 for (int i = 0; i < forw[1].new_size; i++) {
772 data_buf[i] = data[2 * i]; /* real value */
773 // Vincent:
774 if (check_complex and std::abs(data[2 * i + 1]) > 1e-5) {
775 printf("Complex value is not zero (i=%d,data=%g)!!!\n", i,
776 data[2 * i + 1]);
777 if (i > 100)
778 throw std::runtime_error("Complex value is not zero");
779 }
780 }
781 /* communicate (in is data_buf) */
782 back_grid_comm(comm, forw[1], back[1], data_buf.data(), data);
783
784 /* REMARK: Result has to be in data. */
785}
786
787template <typename FloatType> void fft_plan<FloatType>::destroy_plan() {
788 if (plan_handle) {
790 plan_handle = nullptr;
791 }
792}
793
794template <class FloatType>
795FloatType *allocator<FloatType>::allocate(const std::size_t n) const {
796 if (n == 0) {
797 return nullptr;
798 }
799 if (n > std::numeric_limits<std::size_t>::max() / sizeof(FloatType)) {
800 throw std::bad_array_new_length();
801 }
802 void *const pv = fftw<FloatType>::malloc(n * sizeof(FloatType));
803 if (!pv) {
804 throw std::bad_alloc();
805 }
806 return static_cast<FloatType *>(pv);
807}
808
809template <class FloatType>
810void allocator<FloatType>::deallocate(FloatType *const p,
811 std::size_t) const noexcept {
812 fftw<FloatType>::free(static_cast<void *>(p));
813}
814
815template struct allocator<float>;
816template struct allocator<double>;
817
818template struct fft_plan<float>;
819template struct fft_plan<double>;
820
821template struct fft_forw_plan<float>;
822template struct fft_forw_plan<double>;
823
824template struct fft_back_plan<float>;
825template struct fft_back_plan<double>;
826
827template struct fft_data_struct<float>;
828template struct fft_data_struct<double>;
829
830} // namespace fft
Vector implementation and trait types for boost qvm interoperability.
static double * block(double *p, std::size_t index, std::size_t size)
Definition elc.cpp:172
static void fft_sendrecv(T const *const sendbuf, int scount, int dest, T *const recvbuf, int rcount, int source, boost::mpi::communicator const &comm, int tag)
Definition fft.cpp:65
#define REQ_FFT_BACK
Tag for communication in back_grid_comm()
Definition fft.cpp:61
#define REQ_FFT_FORW
Tag for communication in forw_grid_comm()
Definition fft.cpp:59
Routines, row decomposition, data structures and communication for the 3D-FFT.
T product(Vector< T, N > const &v)
Definition Vector.hpp:374
void permute_ifield(int *field, int size, int permute)
permute an integer array field of size size about permute positions.
int get_linear_index(int a, int b, int c, const Vector3i &adim, MemoryOrder memory_order=MemoryOrder::COLUMN_MAJOR)
get the linear index from the position (a,b,c) in a 3D grid of dimensions adim.
Definition index.hpp:43
void pack_block_permute2(FloatType const *const in, FloatType *const out, const int *start, const int *size, const int *dim, int element)
Pack a block with dimensions size[0] * size[1] * size[2] starting at start of an input 3D-grid with d...
Definition fft.cpp:357
void pack_block_permute1(FloatType const *const in, FloatType *const out, const int *start, const int *size, const int *dim, int element)
Pack a block with dimensions size[0] * size[1] * size[2] starting at start of an input 3D-grid with d...
Definition fft.cpp:308
int calc_send_block(const int *pos1, const int *grid1, const int *pos2, const int *grid2, const int *mesh, const double *mesh_off, int *block)
Calculate a send (or recv.) block for grid communication during a decomposition change.
Definition fft.cpp:266
int calc_local_mesh(const int *n_pos, const int *n_grid, const int *mesh, const double *mesh_off, int *loc_mesh, int *start)
Calculate the local fft mesh.
Definition fft.cpp:222
Definition fft.cpp:74
std::optional< std::vector< int > > find_comm_groups(Utils::Vector3i const &grid1, Utils::Vector3i const &grid2, std::span< int const > node_list1, std::span< int > node_list2, std::span< int > pos, std::span< int > my_pos, int rank)
This ugly function does the bookkeeping: which nodes have to communicate to each other,...
Definition fft.cpp:117
void calc_2d_grid(int n, int grid[3])
Calculate most square 2D grid.
Definition fft.cpp:501
int map_3don2d_grid(int const g3d[3], int g2d[3])
Calculate 'best' mapping between a 2D and 3D grid.
Definition fft.cpp:459
void fft_pack_block(FloatType const *const in, FloatType *const out, int const *start, int const *size, int const *dim, int element)
Pack a 3D-block of size size starting at start of an input 3D-grid in with dimension dim into an outp...
Definition packing.hpp:44
void fft_unpack_block(FloatType const *const in, FloatType *const out, int const *start, int const *size, int const *dim, int element)
Unpack a 3D-block in of size size into an output 3D-grid out of size dim starting at position start.
Definition packing.hpp:83
DEVICE_QUALIFIER constexpr pointer data() noexcept
Definition Array.hpp:120
Aligned allocator for FFT data.
Definition vector.hpp:33
FloatType * allocate(std::size_t n) const
Definition fft.cpp:795
void deallocate(FloatType *p, std::size_t) const noexcept
Definition fft.cpp:810
Plan for a backward 1D FFT of a flattened 3D array.
Definition fft.hpp:116
Information about the three one-dimensional FFTs and how the nodes have to communicate in between.
Definition fft.hpp:125
void forward_fft(boost::mpi::communicator const &comm, FloatType *data)
Perform an in-place forward 3D FFT.
Definition fft.cpp:716
void backward_fft(boost::mpi::communicator const &comm, FloatType *data, bool check_complex)
Perform an in-place backward 3D FFT.
Definition fft.cpp:748
int initialize_fft(boost::mpi::communicator const &comm, Utils::Vector3i const &ca_mesh_dim, int const *ca_mesh_margin, Utils::Vector3i const &global_mesh_dim, Utils::Vector3d const &global_mesh_off, int &ks_pnum, Utils::Vector3i const &grid)
Initialize everything connected to the 3D-FFT.
Definition fft.cpp:516
Plan for a forward 1D FFT of a flattened 3D array.
Definition fft.hpp:82
void destroy_plan()
Definition fft.cpp:787
fftwf_complex complex
Definition fft.cpp:85
static auto constexpr malloc
Definition fft.cpp:81
static auto constexpr free
Definition fft.cpp:82
fftw_complex complex
Definition fft.cpp:77
static auto constexpr execute_dft
Definition fft.cpp:80
static auto constexpr destroy_plan
Definition fft.cpp:79
static auto constexpr plan_many_dft
Definition fft.cpp:78