3#include "../kokkos/kokkos_wrapper.hpp"
8template <
typename T,
int Rows,
int Cols >
12 static constexpr int rows = Rows;
13 static constexpr int cols = Cols;
15 static_assert( Rows > 0 && Cols > 0,
"Matrix dimensions must be positive" );
17 KOKKOS_INLINE_FUNCTION
21 static_assert( Rows == 3 && Cols == 3,
"This constructor is only for 3x3 matrices" );
23 mat.
data[0][0] = row0( 0 );
24 mat.
data[0][1] = row0( 1 );
25 mat.
data[0][2] = row0( 2 );
26 mat.
data[1][0] = row1( 0 );
27 mat.
data[1][1] = row1( 1 );
28 mat.
data[1][2] = row1( 2 );
29 mat.
data[2][0] = row2( 0 );
30 mat.
data[2][1] = row2( 1 );
31 mat.
data[2][2] = row2( 2 );
35 KOKKOS_INLINE_FUNCTION
38 static_assert( Rows == 2 && Cols == 2,
"This constructor is only for 2x2 matrices" );
40 mat.
data[0][0] = col0( 0 );
41 mat.
data[0][1] = col1( 0 );
42 mat.
data[1][0] = col0( 1 );
43 mat.
data[1][1] = col1( 1 );
47 KOKKOS_INLINE_FUNCTION
51 static_assert( Rows == 3 && Cols == 3,
"This constructor is only for 3x3 matrices" );
53 mat.
data[0][0] = col0( 0 );
54 mat.
data[0][1] = col1( 0 );
55 mat.
data[0][2] = col2( 0 );
56 mat.
data[1][0] = col0( 1 );
57 mat.
data[1][1] = col1( 1 );
58 mat.
data[1][2] = col2( 1 );
59 mat.
data[2][0] = col0( 2 );
60 mat.
data[2][1] = col1( 2 );
61 mat.
data[2][2] = col2( 2 );
65 KOKKOS_INLINE_FUNCTION
68 static_assert( Rows == 3 && Cols == 3,
"This constructor is only for 3x3 matrices" );
72 mat.
data[0][d] = col( 0 );
73 mat.
data[1][d] = col( 1 );
74 mat.
data[2][d] = col( 2 );
78 KOKKOS_INLINE_FUNCTION
82 for (
int i = 0; i < Rows; ++i )
89 KOKKOS_INLINE_FUNCTION
92 KOKKOS_INLINE_FUNCTION
96 template <
int RHSRows,
int RHSCols >
99 static_assert( Cols == RHSRows,
"Matrix dimensions do not match" );
102 for (
int i = 0; i < Rows; ++i )
104 for (
int j = 0; j < RHSCols; ++j )
106 result( i, j ) = T( 0 );
107 for (
int k = 0; k < Cols; ++k )
109 result( i, j ) +=
data[i][k] * rhs( k, j );
117 KOKKOS_INLINE_FUNCTION
121 for (
int i = 0; i < Rows; ++i )
124 for (
int j = 0; j < Cols; ++j )
126 result( i ) +=
data[i][j] * vec( j );
132 KOKKOS_INLINE_FUNCTION
136 for (
int i = 0; i < Rows; ++i )
138 for (
int j = 0; j < Cols; ++j )
140 result( i, j ) =
data[i][j] * scalar;
146 KOKKOS_INLINE_FUNCTION
149 for (
int i = 0; i < Rows; ++i )
151 for (
int j = 0; j < Cols; ++j )
159 KOKKOS_INLINE_FUNCTION
163 for (
int i = 0; i < Rows; ++i )
165 for (
int j = 0; j < Cols; ++j )
173 KOKKOS_INLINE_FUNCTION
176 for (
int i = 0; i < Rows; ++i )
178 for (
int j = 0; j < Cols; ++j )
186 KOKKOS_INLINE_FUNCTION
190 for (
int i = 0; i < Rows; ++i )
192 for (
int j = 0; j < Cols; ++j )
194 result( j, i ) =
data[i][j];
200 KOKKOS_INLINE_FUNCTION
203 for (
int i = 0; i < Rows; ++i )
205 for (
int j = 0; j < Cols; ++j )
212 KOKKOS_INLINE_FUNCTION
215 for (
int i = 0; i < Rows; ++i )
217 for (
int j = 0; j < Cols; ++j )
225 KOKKOS_INLINE_FUNCTION
229 for (
int i = 0; i < Rows; ++i )
231 for (
int j = 0; j < Cols; ++j )
239 KOKKOS_INLINE_FUNCTION
242 if constexpr ( Rows == 2 && Cols == 2 )
246 else if constexpr ( Rows == 3 && Cols == 3 )
254 static_assert( Rows == -1,
"det() only implemented for 2x2 and 3x3 matrices" );
258 KOKKOS_INLINE_FUNCTION
261 if constexpr ( Rows == 2 && Cols == 2 )
265 Kokkos::abort(
"Singular matrix" );
266 const T invDet = T( 1 ) / d;
267 return { { {
data[1][1] * invDet, -
data[0][1] * invDet }, { -
data[1][0] * invDet,
data[0][0] * invDet } } };
269 else if constexpr ( Rows == 3 && Cols == 3 )
274 Kokkos::abort(
"Singular matrix" );
276 const T
id = T( 1 ) / d;
294 static_assert( Rows == -1,
"inv() only implemented for 2x2 and 3x3 matrices" );
298 KOKKOS_INLINE_FUNCTION
301 if constexpr ( Rows == 2 && Cols == 2 )
305 Kokkos::abort(
"Singular matrix" );
306 const T invDet = T( 1 ) / d;
307 return { { {
data[1][1] * invDet, -
data[0][1] * invDet }, { -
data[1][0] * invDet,
data[0][0] * invDet } } };
309 else if constexpr ( Rows == 3 && Cols == 3 )
314 Kokkos::abort(
"Singular matrix" );
316 const T
id = T( 1 ) / d;
334 static_assert( Rows == -1,
"inv() only implemented for 2x2 and 3x3 matrices" );
338 KOKKOS_INLINE_FUNCTION
341 if constexpr ( Rows == 2 && Cols == 2 )
344 Kokkos::abort(
"Singular matrix" );
345 const T invDet = T( 1 ) /
det;
346 return { { {
data[1][1] * invDet, -
data[0][1] * invDet }, { -
data[1][0] * invDet,
data[0][0] * invDet } } };
348 else if constexpr ( Rows == 3 && Cols == 3 )
352 Kokkos::abort(
"Singular matrix" );
354 const T
id = T( 1 ) /
det;
372 static_assert( Rows == -1,
"inv() only implemented for 2x2 and 3x3 matrices" );
376 KOKKOS_INLINE_FUNCTION
380 for (
int i = 0; i < Rows; ++i )
382 for (
int j = 0; j < Cols; ++j )
384 result( i, j ) = ( i == j ) ?
data[i][j] : T( 0 );
391template <
typename T,
int Rows,
int Cols >
394 for (
int i = 0; i < A.
rows; ++i )
396 for (
int j = 0; j < A.
cols; ++j )
398 os << A( i, j ) <<
" ";
std::ostream & operator<<(std::ostream &os, const Mat< T, Rows, Cols > &A)
Definition mat.hpp:392
Mat< T, Rows, RHSCols > operator*(const Mat< T, RHSRows, RHSCols > &rhs) const
Definition mat.hpp:97
Mat & operator=(const Mat &mat)
Definition mat.hpp:174
constexpr Mat inv_transposed() const
Definition mat.hpp:299
constexpr T det() const
Definition mat.hpp:240
void fill(const T value)
Definition mat.hpp:201
static constexpr Mat from_col_vecs(const Vec< T, Rows > &col0, const Vec< T, Rows > &col1, const Vec< T, Rows > &col2)
Definition mat.hpp:49
T & operator()(int i, int j)
Definition mat.hpp:90
constexpr Mat diagonal() const
Definition mat.hpp:377
static constexpr int cols
Definition mat.hpp:13
T data[Rows][Cols]
Definition mat.hpp:11
static constexpr Mat diagonal_from_vec(const Vec< T, Rows > &diagonal)
Definition mat.hpp:79
static constexpr Mat from_col_vecs(const Vec< T, Rows > &col0, const Vec< T, Rows > &col1)
Definition mat.hpp:36
Mat operator+(const Mat &mat)
Definition mat.hpp:160
Mat & operator+=(const Mat &mat)
Definition mat.hpp:147
constexpr Mat inv_transposed(const T &det) const
Definition mat.hpp:339
Mat operator*(const T &scalar) const
Definition mat.hpp:133
constexpr Mat< T, Cols, Rows > transposed() const
Definition mat.hpp:187
Vec< T, Rows > operator*(const Vec< T, Cols > &vec) const
Definition mat.hpp:118
static constexpr Mat from_single_col_vec(const Vec< T, Cols > &col, const int d)
Definition mat.hpp:66
Mat & hadamard_product(const Mat &mat)
Definition mat.hpp:213
const T & operator()(int i, int j) const
Definition mat.hpp:93
static constexpr int rows
Definition mat.hpp:12
constexpr Mat inv() const
Definition mat.hpp:259
T double_contract(const Mat &mat)
Definition mat.hpp:226
static constexpr Mat from_row_vecs(const Vec< T, Cols > &row0, const Vec< T, Cols > &row1, const Vec< T, Cols > &row2)
Definition mat.hpp:19