Loading...
Searching...
No Matches
mat.hpp
Go to the documentation of this file.
1#pragma once
2
3#include "../kokkos/kokkos_wrapper.hpp"
4#include "./vec.hpp"
5
6namespace terra::dense {
7
8template < typename T, int Rows, int Cols >
9struct Mat
10{
11 T data[Rows][Cols] = {};
12 static constexpr int rows = Rows;
13 static constexpr int cols = Cols;
14
15 static_assert( Rows > 0 && Cols > 0, "Matrix dimensions must be positive" );
16
17 KOKKOS_INLINE_FUNCTION
18 constexpr static Mat
19 from_row_vecs( const Vec< T, Cols >& row0, const Vec< T, Cols >& row1, const Vec< T, Cols >& row2 )
20 {
21 static_assert( Rows == 3 && Cols == 3, "This constructor is only for 3x3 matrices" );
22 Mat mat;
23 mat.data[0][0] = row0( 0 );
24 mat.data[0][1] = row0( 1 );
25 mat.data[0][2] = row0( 2 );
26 mat.data[1][0] = row1( 0 );
27 mat.data[1][1] = row1( 1 );
28 mat.data[1][2] = row1( 2 );
29 mat.data[2][0] = row2( 0 );
30 mat.data[2][1] = row2( 1 );
31 mat.data[2][2] = row2( 2 );
32 return mat;
33 }
34
35 KOKKOS_INLINE_FUNCTION
36 constexpr static Mat from_col_vecs( const Vec< T, Rows >& col0, const Vec< T, Rows >& col1 )
37 {
38 static_assert( Rows == 2 && Cols == 2, "This constructor is only for 2x2 matrices" );
39 Mat mat;
40 mat.data[0][0] = col0( 0 );
41 mat.data[0][1] = col1( 0 );
42 mat.data[1][0] = col0( 1 );
43 mat.data[1][1] = col1( 1 );
44 return mat;
45 }
46
47 KOKKOS_INLINE_FUNCTION
48 constexpr static Mat
49 from_col_vecs( const Vec< T, Rows >& col0, const Vec< T, Rows >& col1, const Vec< T, Rows >& col2 )
50 {
51 static_assert( Rows == 3 && Cols == 3, "This constructor is only for 3x3 matrices" );
52 Mat mat;
53 mat.data[0][0] = col0( 0 );
54 mat.data[0][1] = col1( 0 );
55 mat.data[0][2] = col2( 0 );
56 mat.data[1][0] = col0( 1 );
57 mat.data[1][1] = col1( 1 );
58 mat.data[1][2] = col2( 1 );
59 mat.data[2][0] = col0( 2 );
60 mat.data[2][1] = col1( 2 );
61 mat.data[2][2] = col2( 2 );
62 return mat;
63 }
64
65 KOKKOS_INLINE_FUNCTION
66 constexpr static Mat from_single_col_vec( const Vec< T, Cols >& col, const int d )
67 {
68 static_assert( Rows == 3 && Cols == 3, "This constructor is only for 3x3 matrices" );
69 assert( d < 3 );
70 Mat mat;
71 mat.fill( 0 );
72 mat.data[0][d] = col( 0 );
73 mat.data[1][d] = col( 1 );
74 mat.data[2][d] = col( 2 );
75 return mat;
76 }
77
78 KOKKOS_INLINE_FUNCTION
79 constexpr static Mat diagonal_from_vec( const Vec< T, Rows >& diagonal )
80 {
81 Mat mat;
82 for ( int i = 0; i < Rows; ++i )
83 {
84 mat( i, i ) = diagonal( i );
85 }
86 return mat;
87 }
88
89 KOKKOS_INLINE_FUNCTION
90 T& operator()( int i, int j ) { return data[i][j]; }
91
92 KOKKOS_INLINE_FUNCTION
93 const T& operator()( int i, int j ) const { return data[i][j]; }
94
95 // Matrix-matrix multiplication
96 template < int RHSRows, int RHSCols >
97 KOKKOS_INLINE_FUNCTION Mat< T, Rows, RHSCols > operator*( const Mat< T, RHSRows, RHSCols >& rhs ) const
98 {
99 static_assert( Cols == RHSRows, "Matrix dimensions do not match" );
100
102 for ( int i = 0; i < Rows; ++i )
103 {
104 for ( int j = 0; j < RHSCols; ++j )
105 {
106 result( i, j ) = T( 0 );
107 for ( int k = 0; k < Cols; ++k )
108 {
109 result( i, j ) += data[i][k] * rhs( k, j );
110 }
111 }
112 }
113 return result;
114 }
115
116 // Matrix-vector multiplication
117 KOKKOS_INLINE_FUNCTION
119 {
120 Vec< T, Rows > result;
121 for ( int i = 0; i < Rows; ++i )
122 {
123 result( i ) = 0;
124 for ( int j = 0; j < Cols; ++j )
125 {
126 result( i ) += data[i][j] * vec( j );
127 }
128 }
129 return result;
130 }
131
132 KOKKOS_INLINE_FUNCTION
133 Mat operator*( const T& scalar ) const
134 {
135 Mat result;
136 for ( int i = 0; i < Rows; ++i )
137 {
138 for ( int j = 0; j < Cols; ++j )
139 {
140 result( i, j ) = data[i][j] * scalar;
141 }
142 }
143 return result;
144 }
145
146 KOKKOS_INLINE_FUNCTION
147 Mat& operator+=( const Mat& mat )
148 {
149 for ( int i = 0; i < Rows; ++i )
150 {
151 for ( int j = 0; j < Cols; ++j )
152 {
153 data[i][j] += mat.data[i][j];
154 }
155 }
156 return *this;
157 }
158
159 KOKKOS_INLINE_FUNCTION
160 Mat operator+( const Mat& mat )
161 {
162 Mat result;
163 for ( int i = 0; i < Rows; ++i )
164 {
165 for ( int j = 0; j < Cols; ++j )
166 {
167 result.data[i][j] = data[i][j] + mat.data[i][j];
168 }
169 }
170 return result;
171 }
172
173 KOKKOS_INLINE_FUNCTION
174 Mat& operator=( const Mat& mat )
175 {
176 for ( int i = 0; i < Rows; ++i )
177 {
178 for ( int j = 0; j < Cols; ++j )
179 {
180 data[i][j] = mat.data[i][j];
181 }
182 }
183 return *this;
184 }
185
186 KOKKOS_INLINE_FUNCTION
188 {
190 for ( int i = 0; i < Rows; ++i )
191 {
192 for ( int j = 0; j < Cols; ++j )
193 {
194 result( j, i ) = data[i][j];
195 }
196 }
197 return result;
198 }
199
200 KOKKOS_INLINE_FUNCTION
201 void fill( const T value )
202 {
203 for ( int i = 0; i < Rows; ++i )
204 {
205 for ( int j = 0; j < Cols; ++j )
206 {
207 data[i][j] = value;
208 }
209 }
210 }
211
212 KOKKOS_INLINE_FUNCTION
213 Mat& hadamard_product( const Mat& mat )
214 {
215 for ( int i = 0; i < Rows; ++i )
216 {
217 for ( int j = 0; j < Cols; ++j )
218 {
219 data[i][j] *= mat.data[i][j];
220 }
221 }
222 return *this;
223 }
224
225 KOKKOS_INLINE_FUNCTION
226 T double_contract( const Mat& mat )
227 {
228 T v = 0.0;
229 for ( int i = 0; i < Rows; ++i )
230 {
231 for ( int j = 0; j < Cols; ++j )
232 {
233 v += data[i][j] * mat.data[i][j];
234 }
235 }
236 return v;
237 }
238
239 KOKKOS_INLINE_FUNCTION
240 constexpr T det() const
241 {
242 if constexpr ( Rows == 2 && Cols == 2 )
243 {
244 return data[0][0] * data[1][1] - data[0][1] * data[1][0];
245 }
246 else if constexpr ( Rows == 3 && Cols == 3 )
247 {
248 return data[0][0] * ( data[1][1] * data[2][2] - data[1][2] * data[2][1] ) -
249 data[0][1] * ( data[1][0] * data[2][2] - data[1][2] * data[2][0] ) +
250 data[0][2] * ( data[1][0] * data[2][1] - data[1][1] * data[2][0] );
251 }
252 else
253 {
254 static_assert( Rows == -1, "det() only implemented for 2x2 and 3x3 matrices" );
255 }
256 }
257
258 KOKKOS_INLINE_FUNCTION
259 constexpr Mat inv() const
260 {
261 if constexpr ( Rows == 2 && Cols == 2 )
262 {
263 const T d = det();
264 if ( d == T( 0 ) )
265 Kokkos::abort( "Singular matrix" );
266 const T invDet = T( 1 ) / d;
267 return { { { data[1][1] * invDet, -data[0][1] * invDet }, { -data[1][0] * invDet, data[0][0] * invDet } } };
268 }
269 else if constexpr ( Rows == 3 && Cols == 3 )
270 {
271 const T d = det();
272#ifndef NDEBUG
273 if ( d == T( 0 ) )
274 Kokkos::abort( "Singular matrix" );
275#endif
276 const T id = T( 1 ) / d;
277
279 r( 0, 0 ) = ( data[1][1] * data[2][2] - data[1][2] * data[2][1] ) * id;
280 r( 0, 1 ) = -( data[0][1] * data[2][2] - data[0][2] * data[2][1] ) * id;
281 r( 0, 2 ) = ( data[0][1] * data[1][2] - data[0][2] * data[1][1] ) * id;
282
283 r( 1, 0 ) = -( data[1][0] * data[2][2] - data[1][2] * data[2][0] ) * id;
284 r( 1, 1 ) = ( data[0][0] * data[2][2] - data[0][2] * data[2][0] ) * id;
285 r( 1, 2 ) = -( data[0][0] * data[1][2] - data[0][2] * data[1][0] ) * id;
286
287 r( 2, 0 ) = ( data[1][0] * data[2][1] - data[1][1] * data[2][0] ) * id;
288 r( 2, 1 ) = -( data[0][0] * data[2][1] - data[0][1] * data[2][0] ) * id;
289 r( 2, 2 ) = ( data[0][0] * data[1][1] - data[0][1] * data[1][0] ) * id;
290 return r;
291 }
292 else
293 {
294 static_assert( Rows == -1, "inv() only implemented for 2x2 and 3x3 matrices" );
295 }
296 }
297
298 KOKKOS_INLINE_FUNCTION
299 constexpr Mat inv_transposed() const
300 {
301 if constexpr ( Rows == 2 && Cols == 2 )
302 {
303 const T d = det();
304 if ( d == T( 0 ) )
305 Kokkos::abort( "Singular matrix" );
306 const T invDet = T( 1 ) / d;
307 return { { { data[1][1] * invDet, -data[0][1] * invDet }, { -data[1][0] * invDet, data[0][0] * invDet } } };
308 }
309 else if constexpr ( Rows == 3 && Cols == 3 )
310 {
311 const T d = det();
312#ifndef NDEBUG
313 if ( d == T( 0 ) )
314 Kokkos::abort( "Singular matrix" );
315#endif
316 const T id = T( 1 ) / d;
317
319 r( 0, 0 ) = ( data[1][1] * data[2][2] - data[1][2] * data[2][1] ) * id;
320 r( 0, 1 ) = -( data[1][0] * data[2][2] - data[1][2] * data[2][0] ) * id;
321 r( 0, 2 ) = ( data[1][0] * data[2][1] - data[1][1] * data[2][0] ) * id;
322
323 r( 1, 0 ) = -( data[0][1] * data[2][2] - data[0][2] * data[2][1] ) * id;
324 r( 1, 1 ) = ( data[0][0] * data[2][2] - data[0][2] * data[2][0] ) * id;
325 r( 1, 2 ) = -( data[0][0] * data[2][1] - data[0][1] * data[2][0] ) * id;
326
327 r( 2, 0 ) = ( data[0][1] * data[1][2] - data[0][2] * data[1][1] ) * id;
328 r( 2, 1 ) = -( data[0][0] * data[1][2] - data[0][2] * data[1][0] ) * id;
329 r( 2, 2 ) = ( data[0][0] * data[1][1] - data[0][1] * data[1][0] ) * id;
330 return r;
331 }
332 else
333 {
334 static_assert( Rows == -1, "inv() only implemented for 2x2 and 3x3 matrices" );
335 }
336 }
337
338 KOKKOS_INLINE_FUNCTION
339 constexpr Mat inv_transposed( const T& det ) const
340 {
341 if constexpr ( Rows == 2 && Cols == 2 )
342 {
343 if ( det == T( 0 ) )
344 Kokkos::abort( "Singular matrix" );
345 const T invDet = T( 1 ) / det;
346 return { { { data[1][1] * invDet, -data[0][1] * invDet }, { -data[1][0] * invDet, data[0][0] * invDet } } };
347 }
348 else if constexpr ( Rows == 3 && Cols == 3 )
349 {
350#ifndef NDEBUG
351 if ( det == T( 0 ) )
352 Kokkos::abort( "Singular matrix" );
353#endif
354 const T id = T( 1 ) / det;
355
357 r( 0, 0 ) = ( data[1][1] * data[2][2] - data[1][2] * data[2][1] ) * id;
358 r( 0, 1 ) = -( data[1][0] * data[2][2] - data[1][2] * data[2][0] ) * id;
359 r( 0, 2 ) = ( data[1][0] * data[2][1] - data[1][1] * data[2][0] ) * id;
360
361 r( 1, 0 ) = -( data[0][1] * data[2][2] - data[0][2] * data[2][1] ) * id;
362 r( 1, 1 ) = ( data[0][0] * data[2][2] - data[0][2] * data[2][0] ) * id;
363 r( 1, 2 ) = -( data[0][0] * data[2][1] - data[0][1] * data[2][0] ) * id;
364
365 r( 2, 0 ) = ( data[0][1] * data[1][2] - data[0][2] * data[1][1] ) * id;
366 r( 2, 1 ) = -( data[0][0] * data[1][2] - data[0][2] * data[1][0] ) * id;
367 r( 2, 2 ) = ( data[0][0] * data[1][1] - data[0][1] * data[1][0] ) * id;
368 return r;
369 }
370 else
371 {
372 static_assert( Rows == -1, "inv() only implemented for 2x2 and 3x3 matrices" );
373 }
374 }
375
376 KOKKOS_INLINE_FUNCTION
377 constexpr Mat diagonal() const
378 {
379 Mat result;
380 for ( int i = 0; i < Rows; ++i )
381 {
382 for ( int j = 0; j < Cols; ++j )
383 {
384 result( i, j ) = ( i == j ) ? data[i][j] : T( 0 );
385 }
386 }
387 return result;
388 }
389};
390
391template < typename T, int Rows, int Cols >
392std::ostream& operator<<( std::ostream& os, const Mat< T, Rows, Cols >& A )
393{
394 for ( int i = 0; i < A.rows; ++i )
395 {
396 for ( int j = 0; j < A.cols; ++j )
397 {
398 os << A( i, j ) << " ";
399 }
400 os << '\n';
401 }
402 return os;
403}
404
405} // namespace terra::dense
Definition mat.hpp:6
std::ostream & operator<<(std::ostream &os, const Mat< T, Rows, Cols > &A)
Definition mat.hpp:392
Definition mat.hpp:10
Mat< T, Rows, RHSCols > operator*(const Mat< T, RHSRows, RHSCols > &rhs) const
Definition mat.hpp:97
Mat & operator=(const Mat &mat)
Definition mat.hpp:174
constexpr Mat inv_transposed() const
Definition mat.hpp:299
constexpr T det() const
Definition mat.hpp:240
void fill(const T value)
Definition mat.hpp:201
static constexpr Mat from_col_vecs(const Vec< T, Rows > &col0, const Vec< T, Rows > &col1, const Vec< T, Rows > &col2)
Definition mat.hpp:49
T & operator()(int i, int j)
Definition mat.hpp:90
constexpr Mat diagonal() const
Definition mat.hpp:377
static constexpr int cols
Definition mat.hpp:13
T data[Rows][Cols]
Definition mat.hpp:11
static constexpr Mat diagonal_from_vec(const Vec< T, Rows > &diagonal)
Definition mat.hpp:79
static constexpr Mat from_col_vecs(const Vec< T, Rows > &col0, const Vec< T, Rows > &col1)
Definition mat.hpp:36
Mat operator+(const Mat &mat)
Definition mat.hpp:160
Mat & operator+=(const Mat &mat)
Definition mat.hpp:147
constexpr Mat inv_transposed(const T &det) const
Definition mat.hpp:339
Mat operator*(const T &scalar) const
Definition mat.hpp:133
constexpr Mat< T, Cols, Rows > transposed() const
Definition mat.hpp:187
Vec< T, Rows > operator*(const Vec< T, Cols > &vec) const
Definition mat.hpp:118
static constexpr Mat from_single_col_vec(const Vec< T, Cols > &col, const int d)
Definition mat.hpp:66
Mat & hadamard_product(const Mat &mat)
Definition mat.hpp:213
const T & operator()(int i, int j) const
Definition mat.hpp:93
static constexpr int rows
Definition mat.hpp:12
constexpr Mat inv() const
Definition mat.hpp:259
T double_contract(const Mat &mat)
Definition mat.hpp:226
static constexpr Mat from_row_vecs(const Vec< T, Cols > &row0, const Vec< T, Cols > &row1, const Vec< T, Cols > &row2)
Definition mat.hpp:19
Definition vec.hpp:9