ESPResSo
Extensible Simulation Package for Research on Soft Matter Systems
Loading...
Searching...
No Matches
matrix.hpp
Go to the documentation of this file.
1/*
2 * Copyright (C) 2010-2022 The ESPResSo project
3 *
4 * This file is part of ESPResSo.
5 *
6 * ESPResSo is free software: you can redistribute it and/or modify
7 * it under the terms of the GNU General Public License as published by
8 * the Free Software Foundation, either version 3 of the License, or
9 * (at your option) any later version.
10 *
11 * ESPResSo is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 * GNU General Public License for more details.
15 *
16 * You should have received a copy of the GNU General Public License
17 * along with this program. If not, see <http://www.gnu.org/licenses/>.
18 */
19#ifndef SRC_UTILS_INCLUDE_UTILS_MATRIX_HPP
20#define SRC_UTILS_INCLUDE_UTILS_MATRIX_HPP
21
22/**
23 * @file
24 *
25 * @brief Matrix implementation and trait types
26 * for boost qvm interoperability.
27 */
28
#include "utils/Array.hpp"
#include "utils/Vector.hpp"
#include "utils/flatten.hpp"

#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <functional>
#include <initializer_list>
#include <numeric>
#include <type_traits>
#include <utility>
39
40// These includes need to come first due to ADL reasons.
41// clang-format off
42#include <boost/qvm/mat_operations.hpp>
43#include <boost/qvm/vec_mat_operations.hpp>
44#include <boost/qvm/vec_operations.hpp>
45// clang-format on
46
47#include <boost/qvm/deduce_mat.hpp>
48#include <boost/qvm/deduce_scalar.hpp>
49#include <boost/qvm/deduce_vec.hpp>
50#include <boost/qvm/map_mat_mat.hpp>
51#include <boost/qvm/map_mat_vec.hpp>
52#include <boost/qvm/map_vec_mat.hpp>
53#include <boost/qvm/mat.hpp>
54#include <boost/qvm/mat_access.hpp>
55#include <boost/qvm/mat_traits.hpp>
56
57namespace Utils {
58
59/**
60 * @brief Matrix representation with static size.
61 * @tparam T The data type.
62 * @tparam Rows Number of rows.
63 * @tparam Cols Number of columns.
64 */
65template <typename T, std::size_t Rows, std::size_t Cols> struct Matrix {
67 using pointer = typename container::pointer;
69 using iterator = typename container::iterator;
74
76
77private:
79 template <class Archive> void serialize(Archive &ar, const unsigned int) {
80 ar & m_data;
81 }
82
83public:
84 Matrix() = default;
85 Matrix(std::initializer_list<T> init_list) {
86 assert(init_list.size() == Rows * Cols);
87 std::copy(init_list.begin(), init_list.end(), begin());
88 }
89 Matrix(std::initializer_list<std::initializer_list<T>> init_list) {
90 assert(init_list.size() == Rows);
91 Utils::flatten(init_list, begin());
92 }
93
94 /**
95 * @brief Element access (const).
96 * @param row The row used for access.
97 * @param col The column used for access.
98 * @return The matrix element at row @p row and column @p col.
99 */
100 constexpr value_type operator()(std::size_t row, std::size_t col) const {
101 assert(row < Rows);
102 assert(col < Cols);
103 return m_data[Cols * row + col];
104 }
105 /**
106 * @brief Element access (non const).
107 * @param row The row used for access.
108 * @param col The column used for access.
109 * @return The matrix element at row @p row and column @p col.
110 */
111 constexpr reference operator()(std::size_t row, std::size_t col) {
112 assert(row < Rows);
113 assert(col < Cols);
114 return m_data[Cols * row + col];
115 }
116
117 /**
118 * @brief Access to the underlying data pointer (non const).
119 * @return Pointer to first element of the data.
120 */
121 constexpr pointer data() { return m_data.data(); }
122 /**
123 * @brief Access to the underlying data pointer (non const).
124 * @return Pointer to first element of the data.
125 */
126 constexpr const_pointer data() const noexcept { return m_data.data(); }
127 /**
128 * @brief Iterator access (non const).
129 * @return Returns an iterator to the first element of the matrix.
130 */
131 constexpr iterator begin() noexcept { return m_data.begin(); }
132 /**
133 * @brief Iterator access (const).
134 * @return Returns an iterator to the first element of the matrix.
135 */
136 constexpr const_iterator begin() const noexcept { return m_data.begin(); }
137 /**
138 * @brief Iterator access (non const).
139 * @return Returns an iterator to the element following the last element of
140 * the matrix.
141 */
142 constexpr iterator end() noexcept { return m_data.end(); }
143 /**
144 * @brief Iterator access (non const).
145 * @return Returns an iterator to the element following the last element of
146 * the matrix.
147 */
148 constexpr const_iterator end() const noexcept { return m_data.end(); }
149 /**
150 * @brief Retrieve an entire matrix row.
151 * @tparam R The row index.
152 * @return A vector containing the elements of row @p R.
153 */
154 template <std::size_t R> Vector<T, Cols> row() const {
155 static_assert(R < Rows, "Invalid row index.");
156 return boost::qvm::row<R>(*this);
157 }
158 /**
159 * @brief Retrieve an entire matrix column.
160 * @tparam C The column index.
161 * @return A vector containing the elements of column @p C.
162 */
163 template <std::size_t C> Vector<T, Rows> col() const {
164 static_assert(C < Cols, "Invalid column index.");
165 return boost::qvm::col<C>(*this);
166 }
167 /**
168 * @brief Retrieve the diagonal.
169 * @return Vector containing the diagonal elements of the matrix.
170 */
172 static_assert(Rows == Cols,
173 "Diagonal can only be retrieved from square matrices.");
174 return boost::qvm::diag(*this);
175 }
176 /**
177 * @brief Retrieve the trace.
178 * @return Vector containing the sum of diagonal matrix elements.
179 */
180 T trace() const {
181 auto const d = diagonal();
182 return std::accumulate(d.begin(), d.end(), T{}, std::plus<T>{});
183 }
184
185 /**
186 * @brief Retrieve a transposed copy of the matrix.
187 * @return Transposed matrix.
188 */
190 return boost::qvm::transposed(*this);
191 }
192
193 /**
194 * @brief Retrieve an inverted copy of the matrix.
195 * @return Inverted matrix.
196 */
198 static_assert(Rows == Cols,
199 "Inversion of a non-square matrix not implemented.");
200 return boost::qvm::inverse(*this);
201 }
202 /**
203 * @brief Retrieve the shape of the matrix.
204 * @return Pair containing number of rows and number of columns of the matrix.
205 */
206 constexpr std::pair<std::size_t, std::size_t> shape() const noexcept {
207 return {Rows, Cols};
208 }
209};
210
211using boost::qvm::operator+;
212using boost::qvm::operator+=;
213using boost::qvm::operator-;
214using boost::qvm::operator-=;
215using boost::qvm::operator*;
216using boost::qvm::operator*=;
217using boost::qvm::operator==;
218
219template <typename T, std::size_t M, std::size_t N>
223
224template <typename T, std::size_t Rows, std::size_t Cols>
226 static_assert(Rows == Cols, "Diagonal matrix has to be a square matrix.");
227 return boost::qvm::diag_mat(v);
228}
229
230template <typename T, std::size_t Rows, std::size_t Cols>
232 static_assert(Rows == Cols,
233 "Identity matrix only defined for square matrices.");
234 return boost::qvm::identity_mat<T, Rows>();
235}
236
237} // namespace Utils
238
239namespace boost {
240namespace qvm {
241
242template <typename T, std::size_t Rows, std::size_t Cols>
243struct mat_traits<Utils::Matrix<T, Rows, Cols>> {
245 static int const rows = Rows;
246 static int const cols = Cols;
247 using scalar_type = T;
248
249 template <std::size_t R, std::size_t C>
250 static inline scalar_type read_element(mat_type const &m) {
251 static_assert(R < Rows, "Invalid row index.");
252 static_assert(C < Cols, "Invalid column index.");
253 return m(R, C);
254 }
255
256 template <std::size_t R, std::size_t C>
257 static inline scalar_type &write_element(mat_type &m) {
258 static_assert(R < Rows, "Invalid row index.");
259 static_assert(C < Cols, "Invalid column index.");
260 return m(R, C);
261 }
262
263 static inline scalar_type read_element_idx(std::size_t r, std::size_t c,
264 mat_type const &m) {
265 assert(r < Rows);
266 assert(c < Cols);
267 return m(r, c);
268 }
269 static inline scalar_type &write_element_idx(std::size_t r, std::size_t c,
270 mat_type &m) {
271 assert(r < Rows);
272 assert(c < Cols);
273 return m(r, c);
274 }
275};
276
277template <typename T, typename U>
278struct deduce_vec2<Utils::Matrix<T, 2, 2>, Utils::Vector<U, 2>, 2> {
280};
281
282template <typename T, typename U>
283struct deduce_vec2<Utils::Matrix<T, 3, 3>, Utils::Vector<U, 3>, 3> {
285};
286
287template <typename T, typename U>
288struct deduce_vec2<Utils::Matrix<T, 4, 4>, Utils::Vector<U, 4>, 4> {
290};
291
292template <typename T, typename U>
293struct deduce_vec2<Utils::Matrix<T, 2, 3>, Utils::Vector<U, 3>, 2> {
295};
296
297template <typename T, typename U>
298struct deduce_mat2<Utils::Matrix<T, 3, 3>, Utils::Matrix<U, 3, 3>, 3, 3> {
300};
301
302} // namespace qvm
303} // namespace boost
304#endif // SRC_UTILS_INCLUDE_UTILS_MATRIX_HPP
Array implementation with CUDA support.
Vector implementation and trait types for boost qvm interoperability.
void flatten(Range const &v, OutputIterator out)
Flatten a range of ranges.
Definition flatten.hpp:64
Matrix< T, Rows, Cols > identity_mat()
Definition matrix.hpp:231
Matrix< T, Rows, Cols > diagonal_mat(Utils::Vector< T, Rows > const &v)
Definition matrix.hpp:225
DEVICE_QUALIFIER constexpr pointer data() noexcept
Definition Array.hpp:120
DEVICE_QUALIFIER constexpr iterator begin() noexcept
Definition Array.hpp:128
const value_type & const_reference
Definition Array.hpp:78
const value_type * const_pointer
Definition Array.hpp:82
DEVICE_QUALIFIER constexpr iterator end() noexcept
Definition Array.hpp:140
const value_type * const_iterator
Definition Array.hpp:80
Matrix representation with static size.
Definition matrix.hpp:65
constexpr iterator begin() noexcept
Iterator access (non const).
Definition matrix.hpp:131
Matrix< T, Cols, Rows > transposed() const
Retrieve a transposed copy of the matrix.
Definition matrix.hpp:189
constexpr const_pointer data() const noexcept
Access to the underlying data pointer (const).
Definition matrix.hpp:126
typename container::pointer pointer
Definition matrix.hpp:67
typename container::const_pointer const_pointer
Definition matrix.hpp:68
container m_data
Definition matrix.hpp:75
Vector< T, Cols > diagonal() const
Retrieve the diagonal.
Definition matrix.hpp:171
Matrix(std::initializer_list< std::initializer_list< T > > init_list)
Definition matrix.hpp:89
typename container::const_reference const_reference
Definition matrix.hpp:73
constexpr const_iterator end() const noexcept
Iterator access (const).
Definition matrix.hpp:148
T trace() const
Retrieve the trace.
Definition matrix.hpp:180
Vector< T, Cols > row() const
Retrieve an entire matrix row.
Definition matrix.hpp:154
typename container::iterator iterator
Definition matrix.hpp:69
Matrix()=default
Matrix< T, Rows, Cols > inversed() const
Retrieve an inverted copy of the matrix.
Definition matrix.hpp:197
Vector< T, Rows > col() const
Retrieve an entire matrix column.
Definition matrix.hpp:163
constexpr std::pair< std::size_t, std::size_t > shape() const noexcept
Retrieve the shape of the matrix.
Definition matrix.hpp:206
typename container::value_type value_type
Definition matrix.hpp:71
typename container::reference reference
Definition matrix.hpp:72
typename container::const_iterator const_iterator
Definition matrix.hpp:70
friend class boost::serialization::access
Definition matrix.hpp:78
constexpr iterator end() noexcept
Iterator access (non const).
Definition matrix.hpp:142
constexpr pointer data()
Access to the underlying data pointer (non const).
Definition matrix.hpp:121
constexpr value_type operator()(std::size_t row, std::size_t col) const
Element access (const).
Definition matrix.hpp:100
constexpr const_iterator begin() const noexcept
Iterator access (const).
Definition matrix.hpp:136
constexpr reference operator()(std::size_t row, std::size_t col)
Element access (non const).
Definition matrix.hpp:111
Matrix(std::initializer_list< T > init_list)
Definition matrix.hpp:85
static scalar_type read_element_idx(std::size_t r, std::size_t c, mat_type const &m)
Definition matrix.hpp:263
static scalar_type & write_element_idx(std::size_t r, std::size_t c, mat_type &m)
Definition matrix.hpp:269
static scalar_type read_element(mat_type const &m)
Definition matrix.hpp:250
static scalar_type & write_element(mat_type &m)
Definition matrix.hpp:257
typename Utils::Matrix< T, Rows, Cols > mat_type
Definition matrix.hpp:244