Loading...
Searching...
No Matches
gradient.hpp
Go to the documentation of this file.
1
2#pragma once
3
4#include "../../quadrature/quadrature.hpp"
7#include "dense/vec.hpp"
11#include "linalg/operator.hpp"
13#include "linalg/vector.hpp"
14#include "linalg/vector_q1.hpp"
15#include "util/timer.hpp"
16
18
29template < typename ScalarT >
31{
32 public:
35 using ScalarType = ScalarT;
36
37 private:
39 grid::shell::DistributedDomain domain_coarse_;
40
44 BoundaryConditions bcs_;
45
46 linalg::OperatorApplyMode operator_apply_mode_;
47 linalg::OperatorCommunicationMode operator_communication_mode_;
48
51
54
55 public:
// Constructor of the Gradient operator: parameter list, member-initializer list and body.
// NOTE(review): this listing omits original lines 56, 61, 63 and 65. According to the
// signature shown in the definition index further down, those lines carry the
// constructor name, the boundary_mask_fine parameter, the operator_apply_mode parameter
// (default: OperatorApplyMode::Replace) and the default value of
// operator_communication_mode (OperatorCommunicationMode::CommunicateAdditively).
57 const grid::shell::DistributedDomain& domain_fine,
58 const grid::shell::DistributedDomain& domain_coarse,
59 const grid::Grid3DDataVec< ScalarT, 3 >& grid_fine,
60 const grid::Grid2DDataScalar< ScalarT >& radii_fine,
62 BoundaryConditions bcs,
64 linalg::OperatorCommunicationMode operator_communication_mode =
// All arguments are stored in members of the same name (with trailing underscore).
66 : domain_fine_( domain_fine )
67 , domain_coarse_( domain_coarse )
68 , grid_fine_( grid_fine )
69 , radii_( radii_fine )
70 , boundary_mask_fine_( boundary_mask_fine )
71 , operator_apply_mode_( operator_apply_mode )
72 , operator_communication_mode_( operator_communication_mode )
// Receive buffers and communication plan are both built from the fine domain;
// they are used by apply_impl() for the additive communication of the result.
73 , recv_buffers_( domain_fine )
74 , comm_plan_( domain_fine )
75 {
// BoundaryConditions is declared as BoundaryConditionMapping[2] (an array type),
// so the two entries are copied element by element instead of in the initializer list.
76 bcs_[0] = bcs[0];
77 bcs_[1] = bcs[1];
78 }
79
// Setter set_operator_apply_and_communication_modes( operator_apply_mode, operator_communication_mode ).
// NOTE(review): the line carrying the return type and function name (original line 80)
// is missing from this listing; the name is taken from the definition index.
//
// Overwrites both stored modes. apply_impl() reads them on every call, so the new
// values take effect on the next application of the operator.
81 const linalg::OperatorApplyMode operator_apply_mode,
82 const linalg::OperatorCommunicationMode operator_communication_mode )
83 {
84 operator_apply_mode_ = operator_apply_mode;
85 operator_communication_mode_ = operator_communication_mode;
86 }
87
// Applies the operator to src and writes the result into dst.
//
// Steps: optionally zero dst, cache the grid data of both vectors in members,
// launch the per-cell kernel (operator() of this class) over all cells of the
// fine domain, and finally, if requested, communicate the result additively.
//
// src: input vector (read through src.grid_data()).
// dst: output vector (written through dst.grid_data()).
88 void apply_impl( const SrcVectorType& src, DstVectorType& dst )
89 {
// Times the whole function; no explicit stop() is called on this timer.
90 util::Timer timer_apply( "gradient_apply" );
91
// In Replace mode dst is zeroed first. In any other mode the existing content of dst
// is left in place before the kernel runs.
// NOTE(review): the kernel presumably accumulates into dst_ (the call at the end of
// operator() is listed in the index as atomically_add_local_wedge_vector_coefficients) - confirm.
92 if ( operator_apply_mode_ == linalg::OperatorApplyMode::Replace )
93 {
94 assign( dst, 0 );
95 }
96
// Store the grid data in members so that operator() can access it; the functor
// passed to parallel_for below is *this.
97 src_ = src.grid_data();
98 dst_ = dst.grid_data();
99
100 util::Timer timer_kernel( "gradient_kernel" );
101 Kokkos::parallel_for( "matvec", grid::shell::local_domain_md_range_policy_cells( domain_fine_ ), *this );
// Fence before stopping the kernel timer, so the measured interval ends after the fence returns.
102 Kokkos::fence();
103 timer_kernel.stop();
104
// Exchange dst_ between subdomains using the precomputed plan and receive buffers.
// No reduction argument is passed, so the default of send_recv_with_plan applies
// (CommunicationReduction::SUM according to the signature in the definition index).
105 if ( operator_communication_mode_ == linalg::OperatorCommunicationMode::CommunicateAdditively )
106 {
107 util::Timer timer_comm( "gradient_comm" );
108 terra::communication::shell::send_recv_with_plan( comm_plan_, dst_, recv_buffers_ );
109 }
110 }
111
// Per-cell kernel, invoked by the Kokkos::parallel_for in apply_impl().
//
// For the hex cell ( x_cell, y_cell, r_cell ) of subdomain local_subdomain_id this
// 1. assembles one local matrix A[wedge] per wedge by quadrature,
// 2. modifies it according to the boundary condition if the cell touches a boundary,
// 3. multiplies it with the local source coefficients, and
// 4. adds the result, one vector component at a time, into dst_.
//
// NOTE(review): this listing omits several original lines (116, 129-130, 135, 166, 177,
// 206, 242, 293, 301, 311-312, 317, 321). Among them are the declarations of
// wedge_phy_surf, A, A_tmp, R, src, dst, dst_tmp, dst_d and normal, the filling of
// quad_points / quad_weights, and the name of the final scatter call. Comments below
// that refer to those entities are hedged accordingly.
112 KOKKOS_INLINE_FUNCTION void
113 operator()( const int local_subdomain_id, const int x_cell, const int y_cell, const int r_cell ) const
114 {
115 // Gather surface points for each wedge.
117 wedge_surface_physical_coords( wedge_phy_surf, grid_fine_, local_subdomain_id, x_cell, y_cell );
118
119 // Gather wedge radii.
// r_1 / r_2 are the radii at radial indices r_cell and r_cell + 1 of this cell.
120 const ScalarT r_1 = radii_( local_subdomain_id, r_cell );
121 const ScalarT r_2 = radii_( local_subdomain_id, r_cell + 1 );
122
123 // Quadrature points.
124 constexpr auto num_quad_points = quadrature::quad_felippa_3x2_num_quad_points;
125
126 dense::Vec< ScalarT, 3 > quad_points[num_quad_points];
127 ScalarT quad_weights[num_quad_points];
128
// NOTE(review): the two lines that fill quad_points and quad_weights are missing here
// (presumably quad_felippa_3x2_quad_points / quad_felippa_3x2_quad_weights) - confirm.
131
// Parity of the radial cell index; passed to shape_coarse() below.
132 const int fine_radial_wedge_index = r_cell % 2;
133
134 // Compute the local element matrix.
// NOTE(review): the declaration of A is missing here. From the indexing below and from
// the hadamard_product with an 18 x 6 mask, A[wedge] appears to be an 18 x 6 matrix.
136
137 for ( int q = 0; q < num_quad_points; q++ )
138 {
139 for ( int wedge = 0; wedge < num_wedges_per_hex_cell; wedge++ )
140 {
141 const int fine_lateral_wedge_index = fine_lateral_wedge_idx( x_cell, y_cell, wedge );
142
// Jacobian at the quadrature point, absolute value of its determinant, and the
// inverse transpose that is applied to the reference gradient below.
143 const auto J = jac( wedge_phy_surf[wedge], r_1, r_2, quad_points[q] );
144 const auto det = Kokkos::abs( J.det() );
145 const auto J_inv_transposed = J.inv().transposed();
146
147 for ( int i = 0; i < num_nodes_per_wedge; i++ )
148 {
149 const auto grad_i = grad_shape( i, quad_points[q] );
150
151 for ( int j = 0; j < num_nodes_per_wedge; j++ )
152 {
153 const auto shape_j =
154 shape_coarse( j, fine_radial_wedge_index, fine_lateral_wedge_index, quad_points[q] );
155
// Rows are ordered dimension-major: row d * 6 + i is component d of node i.
// Accumulated integrand: weight * ( -( J^{-T} grad_i )_d * shape_j * |det J| ).
156 for ( int d = 0; d < 3; d++ )
157 {
158 A[wedge]( d * 6 + i, j ) +=
159 quad_weights[q] * ( -( ( J_inv_transposed * grad_i )(d) *shape_j ) * det );
160 }
161 }
162 }
163 }
164 }
165
// Load the local source coefficients of both wedges from src_. All three cell indices
// are halved (integer division) before the lookup.
// NOTE(review): presumably this addresses the coarse cell containing this fine cell - confirm.
167 extract_local_wedge_scalar_coefficients( src, local_subdomain_id, x_cell / 2, y_cell / 2, r_cell / 2, src_ );
168
169 // Boundary treatment
// The CMB flag is tested at radial index r_cell, the SURFACE flag at r_cell + 1.
170 bool at_cmb = util::has_flag( boundary_mask_fine_( local_subdomain_id, x_cell, y_cell, r_cell ), CMB );
171 bool at_surface =
172 util::has_flag( boundary_mask_fine_( local_subdomain_id, x_cell, y_cell, r_cell + 1 ), SURFACE );
173
// Entry-wise mask for the local matrices: 1 keeps an entry, 0 removes it.
174 dense::Mat< ScalarT, 18, 6 > boundary_mask;
175 boundary_mask.fill( 1.0 );
176
178 // flag to later not go through the hassle of checking the bcs
179 bool freeslip_reorder = false;
180
181 if ( at_cmb || at_surface )
182 {
183 // Select the boundary this cell touches; CMB takes precedence if both flags are set.
184 ShellBoundaryFlag sbf = at_cmb ? CMB : SURFACE;
185 BoundaryConditionFlag bcf = get_boundary_condition_flag( bcs_, sbf );
186
187 if ( bcf == DIRICHLET )
188 {
// Zero all rows that belong to boundary nodes: wedge nodes 0..2 at the CMB,
// wedge nodes 3..5 at the surface, for every dimension (dimension-major rows).
189 for ( int dimi = 0; dimi < 3; ++dimi )
190 {
191 for ( int i = 0; i < num_nodes_per_wedge; i++ )
192 {
193 for ( int j = 0; j < num_nodes_per_wedge; j++ )
194 {
195 if ( ( at_cmb && ( i < 3 ) ) || ( at_surface && ( i >= 3 ) ) )
196 {
197 boundary_mask( dimi * num_nodes_per_wedge + i, j ) = 0.0;
198 }
199 }
200 }
201 }
202 }
203 else if ( bcf == FREESLIP )
204 {
205 freeslip_reorder = true;
// NOTE(review): the declaration of A_tmp (and presumably R) is missing around here.
207
208 // reorder source dofs for nodes instead of velocity dims in src vector and local matrix
// Row mapping: dimension-major ( node + dim * 6 ) becomes node-major ( node * 3 + dim ).
209 for ( int wedge = 0; wedge < 2; ++wedge )
210 {
211 for ( int node_idxi = 0; node_idxi < num_nodes_per_wedge; node_idxi++ )
212 {
213 for ( int dimi = 0; dimi < 3; ++dimi )
214 {
215 for ( int node_idxj = 0; node_idxj < num_nodes_per_wedge; node_idxj++ )
216 {
217 A_tmp[wedge]( node_idxi * 3 + dimi, node_idxj ) =
218 A[wedge]( node_idxi + dimi * num_nodes_per_wedge, node_idxj );
219 }
220 }
221 }
222 }
223
224 // assemble rotation matrices for boundary nodes
225 // e.g. if we are at CMB, we need to rotate DoFs 0, 1, 2 of each wedge
226 // at SURFACE, we need to rotate DoFs 3, 4, 5
227
// Lateral ( x, y ) offsets of the three boundary nodes, indexed [wedge][boundary node].
228 constexpr int layer_hex_offset_x[2][3] = { { 0, 1, 0 }, { 1, 0, 1 } };
229 constexpr int layer_hex_offset_y[2][3] = { { 0, 0, 1 }, { 1, 1, 0 } };
230
231 for ( int wedge = 0; wedge < 2; ++wedge )
232 {
233 // make rotation matrix unity
// Only the diagonal is written here.
// NOTE(review): this yields the identity only if R starts out zero-initialised - confirm.
234 for ( int i = 0; i < 18; ++i )
235 {
236 R[wedge]( i, i ) = 1.0;
237 }
238
239 for ( int boundary_node_idx = 0; boundary_node_idx < 3; boundary_node_idx++ )
240 {
241 // compute normal
// NOTE(review): the line that starts this call (and declares normal) is missing here;
// the arguments match the coords( subdomain, x, y, r, coords_shell, coords_radii )
// signature in the index, so normal is presumably the node position - confirm.
// The radial index is r_cell at the CMB and r_cell + 1 otherwise.
243 local_subdomain_id,
244 x_cell + layer_hex_offset_x[wedge][boundary_node_idx],
245 y_cell + layer_hex_offset_y[wedge][boundary_node_idx],
246 r_cell + ( at_cmb ? 0 : 1 ),
247 grid_fine_,
248 radii_ );
249
250 // compute rotation matrix for DoFs on current node
251 auto R_i = trafo_mat_cartesian_to_normal_tangential( normal );
252
253 // insert into wedge-local rotation matrix
// Diagonal 3 x 3 block of this node: rows/cols start at 0 for CMB nodes (0..2)
// and at 9 for surface nodes (3..5) in the node-major ordering.
254 int offset_in_R = at_cmb ? 0 : 9;
255 for ( int dimi = 0; dimi < 3; ++dimi )
256 {
257 for ( int dimj = 0; dimj < 3; ++dimj )
258 {
259 R[wedge](
260 offset_in_R + boundary_node_idx * 3 + dimi,
261 offset_in_R + boundary_node_idx * 3 + dimj ) = R_i( dimi, dimj );
262 }
263 }
264 }
265
266 // transform local matrix to rotated/ normal-tangential space: pre/post multiply with rotation matrices
// Only a multiplication from the left is performed here.
267 // TODO transpose this way?
268 A[wedge] = R[wedge] * A_tmp[wedge];
269
270 // eliminate normal components: Dirichlet on the normal-tangential system
271 int node_start = at_surface ? 3 : 0;
272 int node_end = at_surface ? 6 : 3;
273
// Zero row node_idx * 3, i.e. the first of the three rotated components of each
// boundary node. The same mask entries are written in both wedge iterations.
274 for ( int node_idx = node_start; node_idx < node_end; node_idx++ )
275 {
276 int idx = node_idx * 3;
277 for ( int k = 0; k < 6; ++k )
278 {
279 boundary_mask( idx, k ) = 0.0;
280 }
281 }
282 }
283 }
// Nothing to do for NEUMANN: the local matrices stay unmodified.
284 else if ( bcf == NEUMANN ) {}
285 }
286
287 // apply boundary mask
288 for ( int wedge = 0; wedge < num_wedges_per_hex_cell; wedge++ )
289 {
290 A[wedge].hadamard_product( boundary_mask );
291 }
292
// NOTE(review): the declaration of dst is missing here.
294
// Local matrix-vector products, one per wedge.
295 dst[0] = A[0] * src[0];
296 dst[1] = A[1] * src[1];
297
298 if ( freeslip_reorder )
299 {
300 // transform dst back from nt space
// NOTE(review): the declaration of dst_tmp is missing here.
302 dst_tmp[0] = R[0].transposed() * dst[0];
303 dst_tmp[1] = R[1].transposed() * dst[1];
304 for ( int i = 0; i < 18; ++i )
305 {
306 dst[0]( i ) = dst_tmp[0]( i );
307 dst[1]( i ) = dst_tmp[1]( i );
308 }
309
310 // reorder to dimensionwise ordering
// NOTE(review): the two lines performing the reordering are missing here
// (presumably calls to reorder_local_dofs, which is listed in the index) - confirm.
313 }
314
// Scatter the result one vector component at a time: component d occupies
// entries d * 6 .. d * 6 + 5 of each local result vector.
315 for ( int d = 0; d < 3; d++ )
316 {
// NOTE(review): the declaration of dst_d is missing here.
318 dst_d[0] = dst[0].template slice< 6 >( d * 6 );
319 dst_d[1] = dst[1].template slice< 6 >( d * 6 );
320
// NOTE(review): the line naming the called function is missing; the arguments match
// atomically_add_local_wedge_vector_coefficients from the index - confirm.
322 dst_, local_subdomain_id, x_cell, y_cell, r_cell, d, dst_d );
323 }
324 }
325};
326
328
329} // namespace terra::fe::wedge::operators::shell
Definition communication_plan.hpp:33
ScalarT ScalarType
Definition gradient.hpp:35
Gradient(const grid::shell::DistributedDomain &domain_fine, const grid::shell::DistributedDomain &domain_coarse, const grid::Grid3DDataVec< ScalarT, 3 > &grid_fine, const grid::Grid2DDataScalar< ScalarT > &radii_fine, const grid::Grid4DDataScalar< grid::shell::ShellBoundaryFlag > &boundary_mask_fine, BoundaryConditions bcs, linalg::OperatorApplyMode operator_apply_mode=linalg::OperatorApplyMode::Replace, linalg::OperatorCommunicationMode operator_communication_mode=linalg::OperatorCommunicationMode::CommunicateAdditively)
Definition gradient.hpp:56
void apply_impl(const SrcVectorType &src, DstVectorType &dst)
Definition gradient.hpp:88
void set_operator_apply_and_communication_modes(const linalg::OperatorApplyMode operator_apply_mode, const linalg::OperatorCommunicationMode operator_communication_mode)
Definition gradient.hpp:80
void operator()(const int local_subdomain_id, const int x_cell, const int y_cell, const int r_cell) const
Definition gradient.hpp:113
Parallel data structure organizing the thick spherical shell metadata for distributed (MPI parallel) ...
Definition spherical_shell.hpp:2518
Q1 scalar finite element vector on a distributed shell grid.
Definition vector_q1.hpp:21
const grid::Grid4DDataScalar< ScalarType > & grid_data() const
Get const reference to grid data.
Definition vector_q1.hpp:139
const grid::Grid4DDataVec< ScalarType, VecDim > & grid_data() const
Get const reference to grid data.
Definition vector_q1.hpp:288
Timer supporting RAII scope or manual stop.
Definition timer.hpp:342
void stop()
Stop the timer and record elapsed time.
Definition timer.hpp:364
Concept for types that behave like linear operators.
Definition operator.hpp:57
void send_recv_with_plan(const ShellBoundaryCommPlan< GridDataType > &plan, const GridDataType &data, SubdomainNeighborhoodSendRecvBuffer< typename GridDataType::value_type, grid::grid_data_vec_dim< GridDataType >() > &recv_buffers, CommunicationReduction reduction=CommunicationReduction::SUM)
Definition communication_plan.hpp:652
Definition boundary_mass.hpp:14
constexpr void quad_felippa_3x2_quad_weights(T(&quad_weights)[quad_felippa_3x2_num_quad_points])
Definition wedge/quadrature/quadrature.hpp:93
constexpr int quad_felippa_3x2_num_quad_points
Definition wedge/quadrature/quadrature.hpp:66
constexpr void quad_felippa_3x2_quad_points(dense::Vec< T, 3 >(&quad_points)[quad_felippa_3x2_num_quad_points])
Definition wedge/quadrature/quadrature.hpp:70
constexpr int num_nodes_per_wedge_surface
Definition kernel_helpers.hpp:6
constexpr int fine_lateral_wedge_idx(const int x_cell_fine, const int y_cell_fine, const int wedge_idx_fine)
Returns the lateral wedge index with respect to a coarse grid wedge from the fine wedge indices.
Definition kernel_helpers.hpp:601
void wedge_surface_physical_coords(dense::Vec< T, 3 >(&wedge_surf_phy_coords)[num_wedges_per_hex_cell][num_nodes_per_wedge_surface], const grid::Grid3DDataVec< T, 3 > &lateral_grid, const int local_subdomain_id, const int x_cell, const int y_cell)
Extracts the (unit sphere) surface vertex coords of the two wedges of a hex cell.
Definition kernel_helpers.hpp:26
constexpr void reorder_local_dofs(const DoFOrdering doo_from, const DoFOrdering doo_to, dense::Vec< ScalarT, 18 > &dofs)
Definition kernel_helpers.hpp:619
void atomically_add_local_wedge_vector_coefficients(const grid::Grid4DDataVec< T, VecDim > &global_coefficients, const int local_subdomain_id, const int x_cell, const int y_cell, const int r_cell, const int d, const dense::Vec< T, 6 > local_coefficients[2])
Performs an atomic add of the two local wedge coefficient vectors of a hex cell into the global coeff...
Definition kernel_helpers.hpp:465
constexpr T shape_coarse(const int coarse_node_idx, const int fine_radial_wedge_idx, const int fine_lateral_wedge_idx, const T xi_fine, const T eta_fine, const T zeta_fine)
Definition integrands.hpp:373
constexpr int num_wedges_per_hex_cell
Definition kernel_helpers.hpp:5
void extract_local_wedge_scalar_coefficients(dense::Vec< T, 6 >(&local_coefficients)[2], const int local_subdomain_id, const int x_cell, const int y_cell, const int r_cell, const grid::Grid4DDataScalar< T > &global_coefficients)
Extracts the local vector coefficients for the two wedges of a hex cell from the global coefficient v...
Definition kernel_helpers.hpp:306
constexpr int num_nodes_per_wedge
Definition kernel_helpers.hpp:7
constexpr dense::Vec< T, 3 > grad_shape(const int node_idx, const T xi, const T eta, const T zeta)
Gradient of the full shape function:
Definition integrands.hpp:228
constexpr dense::Mat< T, 3, 3 > jac(const dense::Vec< T, 3 > &p1_phy, const dense::Vec< T, 3 > &p2_phy, const dense::Vec< T, 3 > &p3_phy, const T r_1, const T r_2, const T xi, const T eta, const T zeta)
Definition integrands.hpp:657
dense::Vec< typename CoordsShellType::value_type, 3 > coords(const int subdomain, const int x, const int y, const int r, const CoordsShellType &coords_shell, const CoordsRadiiType &coords_radii)
Definition spherical_shell.hpp:2871
BoundaryConditionMapping[2] BoundaryConditions
Definition shell/bit_masks.hpp:37
ShellBoundaryFlag
FlagLike that indicates boundary types for the thick spherical shell.
Definition shell/bit_masks.hpp:12
Kokkos::MDRangePolicy< Kokkos::Rank< 4 > > local_domain_md_range_policy_cells(const DistributedDomain &distributed_domain)
Definition spherical_shell.hpp:2739
BoundaryConditionFlag get_boundary_condition_flag(const BoundaryConditions bcs, ShellBoundaryFlag sbf)
Retrieve the boundary condition flag that is associated with a location in the shell e....
Definition shell/bit_masks.hpp:42
BoundaryConditionFlag
FlagLike that indicates the type of boundary condition
Definition shell/bit_masks.hpp:25
Kokkos::View< ScalarType ***[VecDim], Layout > Grid3DDataVec
Definition grid_types.hpp:42
Kokkos::View< ScalarType ****, Layout > Grid4DDataScalar
Definition grid_types.hpp:27
Kokkos::View< ScalarType **, Layout > Grid2DDataScalar
Definition grid_types.hpp:21
dense::Mat< ScalarType, 3, 3 > trafo_mat_cartesian_to_normal_tangential(const dense::Vec< ScalarType, 3 > &n_input)
Constructs a robust orthonormal transformation matrix from Cartesian to (normal–tangential–tangential...
Definition local_basis_trafo_normal_tangential.hpp:36
OperatorApplyMode
Modes for applying an operator to a vector.
Definition operator.hpp:30
@ Replace
Overwrite the destination vector.
OperatorCommunicationMode
Modes for communication during operator application.
Definition operator.hpp:40
@ CommunicateAdditively
Communicate and add results.
constexpr bool has_flag(E mask_value, E flag) noexcept
Checks if a bitmask value contains a specific flag.
Definition bit_masking.hpp:43
Definition mat.hpp:10
void fill(const T value)
Definition mat.hpp:201
constexpr Mat< T, Cols, Rows > transposed() const
Definition mat.hpp:187
Mat & hadamard_product(const Mat &mat)
Definition mat.hpp:213