ESPResSo
Extensible Simulation Package for Research on Soft Matter Systems
Loading...
Searching...
No Matches
WalberlaCheckpoint.hpp
Go to the documentation of this file.
/*
 * Copyright (C) 2021-2023 The ESPResSo project
 *
 * This file is part of ESPResSo.
 *
 * ESPResSo is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * ESPResSo is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program. If not, see <http://www.gnu.org/licenses/>.
 */
19
20#pragma once
21
22#include "config/config.hpp"
23
24#ifdef WALBERLA
25
27
28#include <utils/Vector.hpp>
29
30#include <boost/mpi/collectives/broadcast.hpp>
31
32#include <fstream>
33#include <ios>
34#include <memory>
35#include <sstream>
36#include <stdexcept>
37#include <string>
38#include <vector>
39
41
/**
 * @brief Serialization mode of a checkpoint file.
 *
 * @c ascii and @c binary select the file format. The two @c unit_test_*
 * values are not file formats: they make @ref unit_test_handle throw, so
 * that the checkpointing error paths can be exercised from tests.
 */
enum class CptMode : int {
  ascii = 0,
  binary = 1,
  unit_test_runtime_error = -1,
  unit_test_ios_failure = -2
};

/**
 * @brief Inject code for unit tests.
 *
 * @param mode A @ref CptMode value, as an integer.
 * @throws std::runtime_error     if @p mode is @c unit_test_runtime_error.
 * @throws std::ios_base::failure if @p mode is @c unit_test_ios_failure.
 * @throws std::domain_error      if @p mode is not a @ref CptMode value.
 */
inline void unit_test_handle(int mode) {
  switch (mode) {
  case static_cast<int>(CptMode::ascii):
  case static_cast<int>(CptMode::binary):
    return;
  case static_cast<int>(CptMode::unit_test_runtime_error):
    throw std::runtime_error("unit test error");
  case static_cast<int>(CptMode::unit_test_ios_failure):
    throw std::ios_base::failure("unit test error");
  default:
    throw std::domain_error("Unknown mode " + std::to_string(mode));
  }
}
63
64/** Handle for a checkpoint file. */
66private:
67 bool m_binary;
68
69public:
70 std::fstream stream;
71
72 CheckpointFile(std::string const &filename, std::ios_base::openmode mode,
73 bool binary) {
74 m_binary = binary;
75 auto flags = mode;
76 if (m_binary)
77 flags |= std::ios_base::binary;
78 stream.open(filename, flags);
79 }
80
81 ~CheckpointFile() = default;
82
83 template <typename T> void write(T const &value) {
84 if (m_binary) {
85 stream.write(reinterpret_cast<const char *>(&value), sizeof(T));
86 } else {
87 stream << value << "\n";
88 }
89 }
90
91 template <typename T> void write(std::vector<T> const &vector) {
92 if (m_binary) {
93 stream.write(reinterpret_cast<const char *>(vector.data()),
94 vector.size() * sizeof(T));
95 } else {
96 for (auto const &value : vector) {
97 stream << value << "\n";
98 }
99 }
100 }
101
102 template <typename T, std::size_t N>
103 void write(Utils::Vector<T, N> const &vector) {
104 if (m_binary) {
105 stream.write(reinterpret_cast<const char *>(vector.data()),
106 N * sizeof(T));
107 } else {
108 stream << Utils::Vector<T, N>::formatter(" ") << vector << "\n";
109 }
110 }
111
112 template <typename T> void read(T &value) {
113 if (m_binary) {
114 stream.read(reinterpret_cast<char *>(&value), sizeof(T));
115 } else {
116 stream >> value;
117 }
118 }
119
120 template <typename T, std::size_t N> void read(Utils::Vector<T, N> &vector) {
121 if (m_binary) {
122 stream.read(reinterpret_cast<char *>(vector.data()), N * sizeof(T));
123 } else {
124 for (auto &value : vector) {
125 stream >> value;
126 }
127 }
128 }
129
130 template <typename T> void read(std::vector<T> &vector) {
131 if (m_binary) {
132 stream.read(reinterpret_cast<char *>(vector.data()),
133 vector.size() * sizeof(T));
134 } else {
135 for (auto &value : vector) {
136 stream >> value;
137 }
138 }
139 }
140};
141
142template <typename F1, typename F2, typename F3>
143void load_checkpoint_common(Context const &context, std::string const classname,
144 std::string const &filename, int mode,
145 F1 const read_metadata, F2 const read_data,
146 F3 const on_success) {
147 auto const err_msg =
148 std::string("Error while reading " + classname + " checkpoint: ");
149 auto const binary = mode == static_cast<int>(CptMode::binary);
150 auto const &comm = context.get_comm();
151 auto const is_head_node = context.is_head_node();
152
153 // open file and set exceptions
154 CheckpointFile cpfile(filename, std::ios_base::in, binary);
155 if (!cpfile.stream) {
156 if (is_head_node) {
157 throw std::runtime_error(err_msg + "could not open file " + filename);
158 }
159 return;
160 }
161 cpfile.stream.exceptions(std::ios_base::failbit | std::ios_base::badbit);
162
163 try {
166 comm.barrier();
167 on_success();
168 // check EOF
169 if (!binary) {
170 if (cpfile.stream.peek() == '\n') {
171 static_cast<void>(cpfile.stream.get());
172 }
173 }
174 if (cpfile.stream.peek() != EOF) {
175 throw std::runtime_error(err_msg + "extra data found, expected EOF.");
176 }
177 } catch (std::ios_base::failure const &) {
178 auto const eof_error = cpfile.stream.eof();
179 cpfile.stream.close();
180 if (eof_error) {
181 if (is_head_node) {
182 throw std::runtime_error(err_msg + "EOF found.");
183 }
184 return;
185 }
186 if (is_head_node) {
187 throw std::runtime_error(err_msg + "incorrectly formatted data.");
188 }
189 return;
190 } catch (std::runtime_error const &err) {
191 cpfile.stream.close();
192 if (is_head_node) {
193 throw std::runtime_error(err_msg + err.what());
194 }
195 return;
196 }
197}
198
199template <typename F1, typename F2, typename F3>
200void save_checkpoint_common(Context const &context, std::string const classname,
201 std::string const &filename, int mode,
202 F1 const write_metadata, F2 const write_data,
203 F3 const on_failure) {
204 auto const err_msg =
205 std::string("Error while writing " + classname + " checkpoint: ");
206 auto const binary = mode == static_cast<int>(CptMode::binary);
207 auto const &comm = context.get_comm();
208 auto const is_head_node = context.is_head_node();
209
210 // open file and set exceptions
211 auto failure = false;
212 std::shared_ptr<CheckpointFile> cpfile;
213 if (is_head_node) {
214 cpfile =
215 std::make_shared<CheckpointFile>(filename, std::ios_base::out, binary);
216 failure = !cpfile->stream;
217 boost::mpi::broadcast(comm, failure, 0);
218 if (failure) {
219 throw std::runtime_error(err_msg + "could not open file " + filename);
220 }
221 cpfile->stream.exceptions(std::ios_base::failbit | std::ios_base::badbit);
222 if (!binary) {
223 cpfile->stream.precision(16);
224 cpfile->stream << std::fixed;
225 }
226 } else {
227 boost::mpi::broadcast(comm, failure, 0);
228 if (failure) {
229 return;
230 }
231 }
232
233 try {
234 write_metadata(cpfile, context);
235 write_data(cpfile, context);
236 } catch (std::exception const &error) {
237 on_failure(cpfile, context);
238 if (is_head_node) {
239 cpfile->stream.close();
240 if (dynamic_cast<std::ios_base::failure const *>(&error)) {
241 throw std::runtime_error(err_msg + "could not write to " + filename);
242 }
243 throw;
244 }
245 }
246}
247
248} // namespace ScriptInterface::walberla
249
250#endif // WALBERLA
ScriptInterface::Context decorates ScriptInterface::ObjectHandle objects with a context: a creation p...
Vector implementation and trait types for boost qvm interoperability.
Context of an object handle.
Definition Context.hpp:54
virtual bool is_head_node() const =0
virtual boost::mpi::communicator const & get_comm() const =0
CheckpointFile(std::string const &filename, std::ios_base::openmode mode, bool binary)
void read(Utils::Vector< T, N > &vector)
void write(std::vector< T > const &vector)
void write(Utils::Vector< T, N > const &vector)
This file contains the defaults for ESPResSo.
void save_checkpoint_common(Context const &context, std::string const classname, std::string const &filename, int mode, F1 const write_metadata, F2 const write_data, F3 const on_failure)
void unit_test_handle(int mode)
Inject code for unit tests.
void load_checkpoint_common(Context const &context, std::string const classname, std::string const &filename, int mode, F1 const read_metadata, F2 const read_data, F3 const on_success)
T get_value(Variant const &v)
Extract value of specific type T from a Variant.