ESPResSo
Extensible Simulation Package for Research on Soft Matter Systems
Loading...
Searching...
No Matches
WalberlaCheckpoint.hpp
Go to the documentation of this file.
1/*
2 * Copyright (C) 2021-2023 The ESPResSo project
3 *
4 * This file is part of ESPResSo.
5 *
6 * ESPResSo is free software: you can redistribute it and/or modify
7 * it under the terms of the GNU General Public License as published by
8 * the Free Software Foundation, either version 3 of the License, or
9 * (at your option) any later version.
10 *
11 * ESPResSo is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 * GNU General Public License for more details.
15 *
16 * You should have received a copy of the GNU General Public License
17 * along with this program. If not, see <http://www.gnu.org/licenses/>.
18 */
19
20#pragma once
21
22#include <config/config.hpp>
23
24#ifdef ESPRESSO_WALBERLA
25
27
28#include <utils/Vector.hpp>
29
30#include <boost/mpi/collectives/broadcast.hpp>
31
32#include <filesystem>
33#include <fstream>
34#include <ios>
35#include <memory>
36#include <sstream>
37#include <stdexcept>
38#include <string>
39#include <vector>
40
42
/**
 * @brief Checkpoint file mode.
 *
 * The integer value of an enumerator is what is passed as the @c mode
 * argument of the checkpoint readers/writers: @c binary selects raw binary
 * I/O, @c ascii selects text I/O. The two @c unit_test_* values do not
 * select a format; they make @c unit_test_handle() throw on purpose.
 */
enum class CptMode : int {
  ascii = 0,                   ///< text I/O through the stream operators
  binary = 1,                  ///< raw binary I/O
  unit_test_runtime_error = 2, ///< unit_test_handle() throws std::runtime_error
  unit_test_ios_failure = 3,   ///< unit_test_handle() throws std::ios_base::failure
};
49
50/** Inject code for unit tests. */
51inline void unit_test_handle(int mode) {
52 switch (mode) {
53 case static_cast<int>(CptMode::ascii):
54 case static_cast<int>(CptMode::binary):
55 return;
56 case static_cast<int>(CptMode::unit_test_runtime_error):
57 throw std::runtime_error("unit test error");
58 case static_cast<int>(CptMode::unit_test_ios_failure):
59 throw std::ios_base::failure("unit test error");
60 default:
61 throw std::domain_error("Unknown mode " + std::to_string(mode));
62 }
63}
64
65/** Handle for a checkpoint file. */
67private:
68 bool m_binary;
69
70public:
71 std::fstream stream;
72
73 CheckpointFile(std::filesystem::path const &path,
74 std::ios_base::openmode mode, bool binary) {
75 m_binary = binary;
76 auto flags = mode;
77 if (m_binary)
78 flags |= std::ios_base::binary;
79 stream.open(path, flags);
80 }
81
82 ~CheckpointFile() = default;
83
84 template <typename T> void write(T const &value) {
85 if (m_binary) {
86 stream.write(reinterpret_cast<const char *>(&value), sizeof(T));
87 } else {
88 stream << value << "\n";
89 }
90 }
91
92 template <typename T> void write(std::vector<T> const &vector) {
93 if (m_binary) {
94 stream.write(reinterpret_cast<const char *>(vector.data()),
95 vector.size() * sizeof(T));
96 } else {
97 for (auto const &value : vector) {
98 stream << value << "\n";
99 }
100 }
101 }
102
103 template <typename T, std::size_t N>
104 void write(Utils::Vector<T, N> const &vector) {
105 if (m_binary) {
106 stream.write(reinterpret_cast<const char *>(vector.data()),
107 N * sizeof(T));
108 } else {
109 stream << Utils::Vector<T, N>::formatter(" ") << vector << "\n";
110 }
111 }
112
113 template <typename T> void read(T &value) {
114 if (m_binary) {
115 stream.read(reinterpret_cast<char *>(&value), sizeof(T));
116 } else {
117 stream >> value;
118 }
119 }
120
121 template <typename T, std::size_t N> void read(Utils::Vector<T, N> &vector) {
122 if (m_binary) {
123 stream.read(reinterpret_cast<char *>(vector.data()), N * sizeof(T));
124 } else {
125 for (auto &value : vector) {
126 stream >> value;
127 }
128 }
129 }
130
131 template <typename T> void read(std::vector<T> &vector) {
132 if (m_binary) {
133 stream.read(reinterpret_cast<char *>(vector.data()),
134 vector.size() * sizeof(T));
135 } else {
136 for (auto &value : vector) {
137 stream >> value;
138 }
139 }
140 }
141};
142
143template <typename F1, typename F2, typename F3>
144void load_checkpoint_common(Context const &context, std::string const classname,
145 std::filesystem::path const &path, int mode,
146 F1 const read_metadata, F2 const read_data,
147 F3 const on_success) {
148 auto const err_msg =
149 std::string("Error while reading " + classname + " checkpoint: ");
150 auto const binary = mode == static_cast<int>(CptMode::binary);
151 auto const &comm = context.get_comm();
152 auto const is_head_node = context.is_head_node();
153
154 // open file and set exceptions
155 CheckpointFile cpfile(path, std::ios_base::in, binary);
156 if (!cpfile.stream) {
157 if (is_head_node) {
158 throw std::runtime_error(err_msg + "could not open file " +
159 path.string());
160 }
161 return;
162 }
163 cpfile.stream.exceptions(std::ios_base::failbit | std::ios_base::badbit);
164
165 try {
168 comm.barrier();
169 on_success();
170 // check EOF
171 if (!binary) {
172 if (cpfile.stream.peek() == '\n') {
173 static_cast<void>(cpfile.stream.get());
174 }
175 }
176 if (cpfile.stream.peek() != EOF) {
177 throw std::runtime_error(err_msg + "extra data found, expected EOF.");
178 }
179 } catch (std::ios_base::failure const &) {
180 auto const eof_error = cpfile.stream.eof();
181 cpfile.stream.close();
182 if (eof_error) {
183 if (is_head_node) {
184 throw std::runtime_error(err_msg + "EOF found.");
185 }
186 return;
187 }
188 if (is_head_node) {
189 throw std::runtime_error(err_msg + "incorrectly formatted data.");
190 }
191 return;
192 } catch (std::runtime_error const &err) {
193 cpfile.stream.close();
194 if (is_head_node) {
195 throw std::runtime_error(err_msg + err.what());
196 }
197 return;
198 }
199}
200
201template <typename F1, typename F2, typename F3>
202void save_checkpoint_common(Context const &context, std::string const classname,
203 std::filesystem::path const &path, int mode,
204 F1 const write_metadata, F2 const write_data,
205 F3 const on_failure) {
206 auto const err_msg =
207 std::string("Error while writing " + classname + " checkpoint: ");
208 auto const binary = mode == static_cast<int>(CptMode::binary);
209 auto const &comm = context.get_comm();
210 auto const is_head_node = context.is_head_node();
211
212 // open file and set exceptions
213 auto failure = false;
214 std::shared_ptr<CheckpointFile> cpfile;
215 if (is_head_node) {
216 cpfile = std::make_shared<CheckpointFile>(path, std::ios_base::out, binary);
217 failure = !cpfile->stream;
218 boost::mpi::broadcast(comm, failure, 0);
219 if (failure) {
220 throw std::runtime_error(err_msg + "could not open file " +
221 path.string());
222 }
223 cpfile->stream.exceptions(std::ios_base::failbit | std::ios_base::badbit);
224 if (!binary) {
225 cpfile->stream.precision(16);
226 cpfile->stream << std::fixed;
227 }
228 } else {
229 boost::mpi::broadcast(comm, failure, 0);
230 if (failure) {
231 return;
232 }
233 }
234
235 try {
236 write_metadata(cpfile, context);
237 write_data(cpfile, context);
238 } catch (std::exception const &error) {
239 on_failure(cpfile, context);
240 if (is_head_node) {
241 cpfile->stream.close();
242 if (dynamic_cast<std::ios_base::failure const *>(&error)) {
243 throw std::runtime_error(err_msg + "could not write to " +
244 path.string());
245 }
246 throw;
247 }
248 }
249}
250
251} // namespace ScriptInterface::walberla
252
253#endif // ESPRESSO_WALBERLA
ScriptInterface::Context decorates ScriptInterface::ObjectHandle objects with a context: a creation p...
Vector implementation and trait types for boost qvm interoperability.
Context of an object handle.
Definition Context.hpp:53
virtual bool is_head_node() const =0
virtual boost::mpi::communicator const & get_comm() const =0
CheckpointFile(std::filesystem::path const &path, std::ios_base::openmode mode, bool binary)
void read(Utils::Vector< T, N > &vector)
void write(std::vector< T > const &vector)
void write(Utils::Vector< T, N > const &vector)
cudaStream_t stream[1]
CUDA streams for parallel computing on CPU and GPU.
void save_checkpoint_common(Context const &context, std::string const classname, std::filesystem::path const &path, int mode, F1 const write_metadata, F2 const write_data, F3 const on_failure)
void unit_test_handle(int mode)
Inject code for unit tests.
void load_checkpoint_common(Context const &context, std::string const classname, std::filesystem::path const &path, int mode, F1 const read_metadata, F2 const read_data, F3 const on_success)