Loading...
Searching...
No Matches
divergence_kerngen.hpp
Go to the documentation of this file.
1#pragma once
2
3/// @file divergence_kerngen.hpp
4/// @brief Team-based matrix-free Divergence operator for the spherical shell.
5///
6/// This is a performance-oriented variant of `Divergence`
7/// (`fe/wedge/operators/shell/divergence.hpp`) that transfers the optimisation
8/// techniques used in `EpsilonDivDivKerngen`:
9/// - `Kokkos::TeamPolicy` with backend-aware tile sizing.
10/// - Per-team shared-memory staging of coords, radii and velocity src.
11/// - Host-side BC-aware kernel path dispatch (Slow / FastDirichletNeumann /
12/// FastFreeslip) with no in-kernel branching on the path.
13/// - `LaunchBounds<128, 5>` for occupancy tuning on CUDA.
14/// - `ShellBoundaryCommPlan`-based halo exchange.
15///
16/// The kernel path _math_ is intentionally kept identical to the original
17/// `Divergence`: every cell assembles the same local 6x18 element matrix
18/// (`dense::Mat< ScalarT, 6, 18 > A[2]`), applies the same boundary mask /
19/// freeslip rotation, and scatters through the same
20/// `atomically_add_local_wedge_scalar_coefficients` helper onto the coarse
21/// pressure grid `(x/2, y/2, r/2)`. The only difference in the fast paths is
22/// that per-cell input data is read from team shared memory rather than from
23/// global memory. This gives a safe, first-pass optimisation that keeps
24/// correctness provable against the existing operator.
25
26#include "../../quadrature/quadrature.hpp"
29#include "dense/vec.hpp"
33#include "linalg/operator.hpp"
35#include "linalg/vector.hpp"
36#include "linalg/vector_q1.hpp"
37#include "util/timer.hpp"
38
40
51
52template < typename ScalarT >
54{
55 public:
58 using ScalarType = ScalarT;
59 using Team = Kokkos::TeamPolicy<>::member_type;
60
    /// Kernel dispatch path, selected once on the host so the device kernels
    /// never branch on boundary-condition configuration.
    enum class KernelPath
    {
        Slow,
    };
67
68 private:
70 grid::shell::DistributedDomain domain_coarse_;
71
75
76 BoundaryConditions bcs_;
77
78 linalg::OperatorApplyMode operator_apply_mode_;
79 linalg::OperatorCommunicationMode operator_communication_mode_;
80
83
84 // Captured by kernels
87
88 // Tiling (mirror of EpsilonDivDivKerngen)
89 int local_subdomains_;
90 int hex_lat_; // fine cells per side
91 int hex_rad_; // fine cells radially
92 int lat_tile_;
93 int r_tile_;
94 int r_passes_;
95 int r_tile_block_;
96 int lat_tiles_;
97 int r_tiles_;
98 int team_size_;
99 int blocks_;
100
102
103 /// Host-side path selection, updated whenever BCs change.
104 void update_kernel_path_flag_host_only()
105 {
106 // Serial backend: keep the slow path (fast paths assume a real team of cooperative threads).
107 if constexpr ( std::is_same_v< Kokkos::DefaultExecutionSpace, Kokkos::Serial > )
108 {
109 kernel_path_ = KernelPath::Slow;
110 return;
111 }
112
113 const BoundaryConditionFlag cmb_bc = get_boundary_condition_flag( bcs_, CMB );
114 const BoundaryConditionFlag surface_bc = get_boundary_condition_flag( bcs_, SURFACE );
115 const bool has_freeslip = ( cmb_bc == FREESLIP ) || ( surface_bc == FREESLIP );
116
117 kernel_path_ = has_freeslip ? KernelPath::FastFreeslip : KernelPath::FastDirichletNeumann;
118 }
119
  public:
    /// Constructor: captures the fine/coarse domains, fine grid coordinates,
    /// radii, boundary mask and boundary conditions, then derives the
    /// backend-aware team tiling and selects the kernel dispatch path.
        const grid::shell::DistributedDomain& domain_fine,
        const grid::shell::DistributedDomain& domain_coarse,
        const grid::Grid3DDataVec< ScalarT, 3 >& grid_fine,
        const grid::Grid2DDataScalar< ScalarT >& radii_fine,
        BoundaryConditions bcs,
        linalg::OperatorCommunicationMode operator_communication_mode =
    : domain_fine_( domain_fine )
    , domain_coarse_( domain_coarse )
    , grid_fine_( grid_fine )
    , radii_( radii_fine )
    , boundary_mask_fine_( boundary_mask_fine )
    , operator_apply_mode_( operator_apply_mode )
    , operator_communication_mode_( operator_communication_mode )
    , recv_buffers_( domain_coarse )
    , comm_plan_( domain_coarse )
    {
        // Copy both boundary-condition slots (CMB and surface).
        bcs_[0] = bcs[0];
        bcs_[1] = bcs[1];

        // Fine cells per subdomain = nodes per side minus one, laterally and radially.
        const grid::shell::DomainInfo& domain_info = domain_fine_.domain_info();
        local_subdomains_ = domain_fine_.subdomains().size();
        hex_lat_ = domain_info.subdomain_num_nodes_per_side_laterally() - 1;
        hex_rad_ = domain_info.subdomain_num_nodes_radially() - 1;

        // Backend-aware tile sizing (mirror of EpsilonDivDivKerngen).
        if constexpr ( std::is_same_v< Kokkos::DefaultExecutionSpace, Kokkos::Serial > )
        {
            // Serial: degenerate 1x1x1 tile (single "team thread").
            lat_tile_ = 1; r_tile_ = 1; r_passes_ = 1;
        }
#ifdef KOKKOS_ENABLE_OPENMP
        else if constexpr ( std::is_same_v< Kokkos::DefaultExecutionSpace, Kokkos::OpenMP > )
        {
            // OpenMP: team size is capped by hardware concurrency and by
            // Kokkos' host-team member limit; pick tiles that fit.
            const int max_team = std::min( Kokkos::OpenMP().concurrency(),
                                           static_cast< int >( Kokkos::Impl::HostThreadTeamData::max_team_members ) );
            if ( max_team >= 64 ) { lat_tile_ = 4; r_tile_ = 4; r_passes_ = 4; }
            else if ( max_team >= 16 ) { lat_tile_ = 4; r_tile_ = 1; r_passes_ = 16; }
            else { lat_tile_ = 1; r_tile_ = 1; r_passes_ = 1; }
        }
#endif
        else
        {
            // Device backends: 4x4 lateral x 8 radial => 128 threads per team.
            lat_tile_ = 4; r_tile_ = 8; r_passes_ = 2;
        }
        // Derived tiling: ceil-divided tile counts, team size, league size.
        r_tile_block_ = r_tile_ * r_passes_;
        lat_tiles_ = ( hex_lat_ + lat_tile_ - 1 ) / lat_tile_;
        r_tiles_ = ( hex_rad_ + r_tile_block_ - 1 ) / r_tile_block_;
        team_size_ = lat_tile_ * lat_tile_ * r_tile_;
        blocks_ = local_subdomains_ * lat_tiles_ * lat_tiles_ * r_tiles_;

        update_kernel_path_flag_host_only();

        util::logroot << "[DivergenceKerngen] tile=(" << lat_tile_ << "," << lat_tile_ << "," << r_tile_
                      << "), r_passes=" << r_passes_ << ", team=" << team_size_ << ", blocks=" << blocks_
                      << ", path=" << path_name() << std::endl;
    }
179
180 const char* path_name() const
181 {
182 switch ( kernel_path_ )
183 {
184 case KernelPath::Slow: return "slow";
185 case KernelPath::FastFreeslip: return "fast-freeslip";
186 default: return "fast-dirichlet-neumann";
187 }
188 }
189
    /// Returns the kernel path currently selected for apply().
    KernelPath kernel_path() const { return kernel_path_; }

    /// Force the slow path. Useful for validation/testing the fast paths against a reference.
    void force_slow_path() { kernel_path_ = KernelPath::Slow; }
194
    /// Updates how apply() writes its result (replace vs. additive) and whether
    /// it performs the additive halo communication afterwards.
        const linalg::OperatorApplyMode operator_apply_mode,
        const linalg::OperatorCommunicationMode operator_communication_mode )
    {
        operator_apply_mode_ = operator_apply_mode;
        operator_communication_mode_ = operator_communication_mode;
    }
202
203 // -------------------------------------------------------------------------
204 // Apply
205 // -------------------------------------------------------------------------
206 void apply_impl( const SrcVectorType& src, DstVectorType& dst )
207 {
208 util::Timer timer_apply( "divergence_apply" );
209
210 if ( operator_apply_mode_ == linalg::OperatorApplyMode::Replace )
211 {
212 assign( dst, 0 );
213 }
214
215 src_ = src.grid_data();
216 dst_ = dst.grid_data();
217
218 util::Timer timer_kernel( "divergence_kernel" );
219 Kokkos::TeamPolicy<> policy( blocks_, team_size_ );
220 if ( kernel_path_ != KernelPath::Slow )
221 {
222 policy.set_scratch_size( 0, Kokkos::PerTeam( team_shmem_size( team_size_ ) ) );
223 }
224
225 if ( kernel_path_ == KernelPath::Slow )
226 {
227 Kokkos::parallel_for(
228 "divergence_apply_kernel_slow", policy, KOKKOS_CLASS_LAMBDA( const Team& team ) {
229 this->run_team_slow( team );
230 } );
231 }
232 else if ( kernel_path_ == KernelPath::FastFreeslip )
233 {
234 Kokkos::parallel_for(
235 "divergence_apply_kernel_fast_fs", policy, KOKKOS_CLASS_LAMBDA( const Team& team ) {
236 this->template run_team_fast< /*Freeslip=*/true >( team );
237 } );
238 }
239 else
240 {
241 Kokkos::TeamPolicy< Kokkos::LaunchBounds< 128, 5 > > dn_policy( blocks_, team_size_ );
242 dn_policy.set_scratch_size( 0, Kokkos::PerTeam( team_shmem_size( team_size_ ) ) );
243 Kokkos::parallel_for(
244 "divergence_apply_kernel_fast_dn", dn_policy, KOKKOS_CLASS_LAMBDA( const Team& team ) {
245 this->template run_team_fast< /*Freeslip=*/false >( team );
246 } );
247 }
248
249 Kokkos::fence();
250 timer_kernel.stop();
251
252 if ( operator_communication_mode_ == linalg::OperatorCommunicationMode::CommunicateAdditively )
253 {
254 util::Timer timer_comm( "divergence_comm" );
255 terra::communication::shell::send_recv_with_plan( comm_plan_, dst_, recv_buffers_ );
256 }
257 }
258
259 // -------------------------------------------------------------------------
260 // Team shmem sizing & decode helpers
261 // -------------------------------------------------------------------------
262 KOKKOS_INLINE_FUNCTION
263 size_t team_shmem_size( const int /*ts*/ ) const
264 {
265 const int nlev = r_tile_block_ + 1;
266 const int n = lat_tile_ + 1;
267 const int nxy = n * n;
268 // coords_sh(nxy,3) + src_sh(nxy,3,nlev) + r_sh(nlev)
269 const size_t nscalars = size_t( nxy ) * 3 + size_t( nxy ) * 3 * nlev + size_t( nlev );
270 return sizeof( ScalarType ) * nscalars;
271 }
272
273 private:
274 KOKKOS_INLINE_FUNCTION
275 void decode_team_indices(
276 const Team& team,
277 int& local_subdomain_id,
278 int& x0,
279 int& y0,
280 int& r0,
281 int& tx,
282 int& ty,
283 int& tr,
284 int& x_cell,
285 int& y_cell,
286 int& r_cell ) const
287 {
288 int tmp = team.league_rank();
289 const int r_tile_id = tmp % r_tiles_;
290 tmp /= r_tiles_;
291 const int lat_y_id = tmp % lat_tiles_;
292 tmp /= lat_tiles_;
293 const int lat_x_id = tmp % lat_tiles_;
294 tmp /= lat_tiles_;
295 local_subdomain_id = tmp;
296
297 x0 = lat_x_id * lat_tile_;
298 y0 = lat_y_id * lat_tile_;
299 r0 = r_tile_id * r_tile_block_;
300
301 const int tid = team.team_rank();
302 tr = tid % r_tile_;
303 tx = ( tid / r_tile_ ) % lat_tile_;
304 ty = tid / ( r_tile_ * lat_tile_ );
305
306 x_cell = x0 + tx;
307 y_cell = y0 + ty;
308 r_cell = r0 + tr;
309 }
310
    /// True iff the fine-grid boundary mask at node (s, x, y, r) carries `flag`.
    KOKKOS_INLINE_FUNCTION
    bool has_flag( int s, int x, int y, int r, grid::shell::ShellBoundaryFlag flag ) const
    {
        return util::has_flag( boundary_mask_fine_( s, x, y, r ), flag );
    }
316
317 // -------------------------------------------------------------------------
318 // Slow path: team-policy wrapper that calls the original per-cell kernel.
319 // Kept verbatim (same math) as divergence.hpp so it serves as the reference.
320 // -------------------------------------------------------------------------
321 KOKKOS_INLINE_FUNCTION
322 void run_team_slow( const Team& team ) const
323 {
324 int local_subdomain_id, x0, y0, r0, tx, ty, tr, x_cell, y_cell, r_cell;
325 decode_team_indices( team, local_subdomain_id, x0, y0, r0, tx, ty, tr, x_cell, y_cell, r_cell );
326
327 if ( tr >= r_tile_ )
328 return;
329
330 for ( int pass = 0; pass < r_passes_; ++pass )
331 {
332 const int r_cell_pass = r0 + pass * r_tile_ + tr;
333 if ( r_cell_pass >= hex_rad_ )
334 break;
335 if ( x_cell >= hex_lat_ || y_cell >= hex_lat_ )
336 continue;
337
338 process_cell_from_global( local_subdomain_id, x_cell, y_cell, r_cell_pass );
339 }
340 }
341
    // Per-cell compute: reads src/coords/radii directly from global memory.
    // Bit-identical (modulo FMA reordering) to Divergence::operator() in divergence.hpp.
    KOKKOS_INLINE_FUNCTION
    void process_cell_from_global( int s, int x_cell, int y_cell, int r_cell ) const
    {
        // Gather the unit-sphere surface coords of the cell's two wedges.
        wedge_surface_physical_coords( wedge_phy_surf, grid_fine_, s, x_cell, y_cell );

        // Inner/outer radii bounding this radial cell.
        const ScalarT r_1 = radii_( s, r_cell );
        const ScalarT r_2 = radii_( s, r_cell + 1 );

        dense::Vec< ScalarT, 3 > quad_points[quadrature::quad_felippa_1x1_num_quad_points];

        // Re-pack the per-dimension wedge coefficients (6 nodes each) into one
        // 18-entry vector per wedge, ordered [d * 6 + node].
        dense::Vec< ScalarT, 18 > src_loc[num_wedges_per_hex_cell];
        for ( int d = 0; d < 3; d++ )
        {
            dense::Vec< ScalarT, 6 > src_d[num_wedges_per_hex_cell];
            extract_local_wedge_vector_coefficients( src_d, s, x_cell, y_cell, r_cell, d, src_ );
            for ( int w = 0; w < num_wedges_per_hex_cell; ++w )
                for ( int i = 0; i < num_nodes_per_wedge; ++i )
                    src_loc[w]( d * 6 + i ) = src_d[w]( i );
        }

        // Assemble, apply the BC treatment and scatter onto the coarse grid.
        eval_cell( s, x_cell, y_cell, r_cell, r_1, r_2, wedge_phy_surf, quad_points, quad_weights, src_loc );
    }
370
    // Shared core: given per-cell inputs already gathered, assemble local 6x18 A matrices,
    // apply boundary treatment, matvec, scatter to coarse grid. Used by both slow and fast paths.
    KOKKOS_INLINE_FUNCTION
    void eval_cell(
        const int s,
        const int x_cell,
        const int y_cell,
        const int r_cell,
        const ScalarT r_1,
        const ScalarT r_2,
        const dense::Vec< ScalarT, 3 > ( &wedge_phy_surf )[num_wedges_per_hex_cell][num_nodes_per_wedge_surface],
        const dense::Vec< ScalarT, 3 > ( &quad_points )[quadrature::quad_felippa_1x1_num_quad_points],
        const ScalarT ( &quad_weights )[quadrature::quad_felippa_1x1_num_quad_points],
        const dense::Vec< ScalarT, 18 > ( &src_in )[num_wedges_per_hex_cell] ) const
    {
        constexpr int num_quad_points = quadrature::quad_felippa_1x1_num_quad_points;
        // Fine cell's radial position within its coarse parent (0 = lower, 1 = upper).
        const int fine_radial_wedge_index = r_cell % 2;

        dense::Mat< ScalarT, 6, 18 > A[num_wedges_per_hex_cell] = {};

        // Quadrature loop: A[wedge](i, d*6+j) accumulates
        // -qw * (J^{-T} grad_j)_d * shape_i * |det J|.
        for ( int q = 0; q < num_quad_points; q++ )
        {
            for ( int wedge = 0; wedge < num_wedges_per_hex_cell; wedge++ )
            {
                const int fine_lateral_wedge_index = fine_lateral_wedge_idx( x_cell, y_cell, wedge );
                const auto J = jac( wedge_phy_surf[wedge], r_1, r_2, quad_points[q] );
                const auto det = Kokkos::abs( J.det() );
                const auto J_inv_transposed = J.inv().transposed();

                for ( int i = 0; i < num_nodes_per_wedge; i++ )
                {
                    // Coarse-grid test function evaluated on this fine wedge.
                    const auto shape_i =
                        shape_coarse( i, fine_radial_wedge_index, fine_lateral_wedge_index, quad_points[q] );

                    for ( int j = 0; j < num_nodes_per_wedge; j++ )
                    {
                        const auto grad_j = grad_shape( j, quad_points[q] );
                        for ( int d = 0; d < 3; d++ )
                        {
                            A[wedge]( i, d * 6 + j ) +=
                                quad_weights[q] * ( -( J_inv_transposed * grad_j )( d ) * shape_i * det );
                        }
                    }
                }
            }
        }

        // The cell touches a shell boundary if its bottom face lies on the CMB
        // or its top face on the surface (mask read at r_cell resp. r_cell + 1).
        bool at_cmb = util::has_flag( boundary_mask_fine_( s, x_cell, y_cell, r_cell ), CMB );
        bool at_surface = util::has_flag( boundary_mask_fine_( s, x_cell, y_cell, r_cell + 1 ), SURFACE );

        // Locally modify src + A for FreeSlip / Dirichlet.
        dense::Vec< ScalarT, 18 > src_loc[num_wedges_per_hex_cell];
        for ( int w = 0; w < num_wedges_per_hex_cell; ++w )
            src_loc[w] = src_in[w];

        // Entry-wise mask applied to A after the BC treatment; 1 everywhere
        // unless a Dirichlet column or a rotated freeslip normal column is zeroed.
        dense::Mat< ScalarT, 6, 18 > boundary_mask;
        boundary_mask.fill( 1.0 );

        if ( at_cmb || at_surface )
        {
            ShellBoundaryFlag sbf = at_cmb ? CMB : SURFACE;

            if ( bcf == DIRICHLET )
            {
                // Dirichlet: zero the columns of the boundary nodes (wedge nodes
                // 0..2 at the CMB, 3..5 at the surface) in all three dimensions.
                for ( int dimj = 0; dimj < 3; ++dimj )
                    for ( int i = 0; i < num_nodes_per_wedge; i++ )
                        for ( int j = 0; j < num_nodes_per_wedge; j++ )
                            if ( ( at_cmb && j < 3 ) || ( at_surface && j >= 3 ) )
                                boundary_mask( i, dimj * num_nodes_per_wedge + j ) = 0.0;
            }
            else if ( bcf == FREESLIP )
            {
                // Freeslip: rotate the boundary nodes' velocity DoFs into a
                // normal/tangential frame and mask the normal component.
                dense::Mat< ScalarT, 6, 18 > A_tmp[num_wedges_per_hex_cell] = { 0 };

                // Reorder columns from [d * 6 + node] to [node * 3 + d] so each
                // node's 3x3 rotation block is contiguous.
                for ( int wedge = 0; wedge < 2; ++wedge )
                {
                    for ( int node_idxi = 0; node_idxi < num_nodes_per_wedge; ++node_idxi )
                        for ( int dimj = 0; dimj < 3; ++dimj )
                            for ( int node_idxj = 0; node_idxj < num_nodes_per_wedge; ++node_idxj )
                                A_tmp[wedge]( node_idxi, node_idxj * 3 + dimj ) =
                                    A[wedge]( node_idxi, node_idxj + dimj * num_nodes_per_wedge );
                }

                // Lateral hex-corner offsets of the three boundary nodes per wedge.
                constexpr int layer_hex_offset_x[2][3] = { { 0, 1, 0 }, { 1, 0, 1 } };
                constexpr int layer_hex_offset_y[2][3] = { { 0, 0, 1 }, { 1, 1, 0 } };

                dense::Mat< ScalarT, 18, 18 > R[num_wedges_per_hex_cell];
                for ( int wedge = 0; wedge < 2; ++wedge )
                {
                    // Start from the identity; only boundary-node blocks rotate.
                    for ( int i = 0; i < 18; ++i )
                        R[wedge]( i, i ) = 1.0;

                    for ( int bn = 0; bn < 3; ++bn )
                    {
                        // Physical position of the boundary node (bottom layer at
                        // the CMB, top layer at the surface) defines the normal.
                        dense::Vec< double, 3 > normal = grid::shell::coords(
                            s,
                            x_cell + layer_hex_offset_x[wedge][bn],
                            y_cell + layer_hex_offset_y[wedge][bn],
                            r_cell + ( at_cmb ? 0 : 1 ),
                            grid_fine_,
                            radii_ );

                        auto R_i = trafo_mat_cartesian_to_normal_tangential( normal );
                        // CMB nodes occupy DoFs 0..8, surface nodes DoFs 9..17.
                        const int offset_in_R = at_cmb ? 0 : 9;
                        for ( int dimi = 0; dimi < 3; ++dimi )
                            for ( int dimj = 0; dimj < 3; ++dimj )
                                R[wedge]( offset_in_R + bn * 3 + dimi, offset_in_R + bn * 3 + dimj ) = R_i( dimi, dimj );
                    }

                    // A <- A_tmp * R^T and src <- R * src: equivalent to working
                    // in the rotated (normal/tangential) basis.
                    A[wedge] = A_tmp[wedge] * R[wedge].transposed();

                    auto src_tmp = R[wedge] * src_loc[wedge];
                    for ( int i = 0; i < 18; ++i )
                        src_loc[wedge]( i ) = src_tmp( i );

                    // Zero the (rotated) normal component column of each boundary node.
                    const int node_start = at_surface ? 3 : 0;
                    const int node_end = at_surface ? 6 : 3;
                    for ( int node_idx = node_start; node_idx < node_end; ++node_idx )
                    {
                        const int idx = node_idx * 3;
                        for ( int k = 0; k < 6; ++k )
                            boundary_mask( k, idx ) = 0.0;
                    }
                }
            }
            else if ( bcf == NEUMANN )
            {
                // No mask modification for Neumann (matches legacy Divergence).
            }
        }

        // Apply the boundary mask entry-wise, then the local matvec per wedge.
        for ( int wedge = 0; wedge < num_wedges_per_hex_cell; wedge++ )
            A[wedge].hadamard_product( boundary_mask );

        dense::Vec< ScalarT, 6 > dst_loc[num_wedges_per_hex_cell];
        dst_loc[0] = A[0] * src_loc[0];
        dst_loc[1] = A[1] * src_loc[1];

        // Atomic scatter onto the coarse pressure grid at (x/2, y/2, r/2).
            dst_, s, x_cell / 2, y_cell / 2, r_cell / 2, dst_loc );
    }
514
    // -------------------------------------------------------------------------
    // Fast paths: team-policy with shmem-cached coords / radii / src velocity.
    // Templated on Freeslip so the rotation code is dead-eliminated on the DN path.
    // -------------------------------------------------------------------------
    template < bool Freeslip >
    KOKKOS_INLINE_FUNCTION
    void run_team_fast( const Team& team ) const
    {
        int local_subdomain_id, x0, y0, r0, tx, ty, tr, x_cell, y_cell, r_cell;
        decode_team_indices( team, local_subdomain_id, x0, y0, r0, tx, ty, tr, x_cell, y_cell, r_cell );

        // Staging extents: one more node than cells per direction.
        const int nlev = r_tile_block_ + 1;
        const int n = lat_tile_ + 1;
        const int nxy = n * n;

        double* shmem =
            reinterpret_cast< double* >( team.team_shmem().get_shmem( team_shmem_size( team.team_size() ) ) );

        using ScratchCoords =
            Kokkos::View< double**, Kokkos::LayoutRight, typename Team::scratch_memory_space, Kokkos::MemoryUnmanaged >;
        using ScratchSrc =
            Kokkos::View< double***, Kokkos::LayoutRight, typename Team::scratch_memory_space, Kokkos::MemoryUnmanaged >;
        using ScratchR =
            Kokkos::View< double*, Kokkos::LayoutRight, typename Team::scratch_memory_space, Kokkos::MemoryUnmanaged >;

        // Carve the raw scratch allocation into three unmanaged views:
        // lateral node coords, staged velocity, and per-level radii.
        ScratchCoords coords_sh( shmem, nxy, 3 );
        shmem += nxy * 3;
        ScratchSrc src_sh( shmem, nxy, 3, nlev );
        shmem += nxy * 3 * nlev;
        ScratchR r_sh( shmem, nlev );

        // Flat index of a lateral node within the (n x n) tile.
        auto node_id = [&]( int nx, int ny ) -> int { return nx + n * ny; };

        // Cooperative load of the tile's lateral node coords; out-of-range
        // entries (tiles overhanging the subdomain edge) are zero-filled.
        Kokkos::parallel_for( Kokkos::TeamThreadRange( team, nxy ), [&]( int id ) {
            const int dxn = id % n;
            const int dyn = id / n;
            const int xi = x0 + dxn;
            const int yi = y0 + dyn;
            if ( xi <= hex_lat_ && yi <= hex_lat_ )
            {
                coords_sh( id, 0 ) = grid_fine_( local_subdomain_id, xi, yi, 0 );
                coords_sh( id, 1 ) = grid_fine_( local_subdomain_id, xi, yi, 1 );
                coords_sh( id, 2 ) = grid_fine_( local_subdomain_id, xi, yi, 2 );
            }
            else
            {
                coords_sh( id, 0 ) = coords_sh( id, 1 ) = coords_sh( id, 2 ) = 0.0;
            }
        } );

        // Cooperative load of the tile's radii.
        Kokkos::parallel_for( Kokkos::TeamThreadRange( team, nlev ), [&]( int lvl ) {
            const int rr = r0 + lvl;
            r_sh( lvl ) = ( rr <= hex_rad_ ) ? radii_( local_subdomain_id, rr ) : 0.0;
        } );

        // Cooperative load of the tile's velocity src, all three components.
        const int total_src = nxy * nlev;
        Kokkos::parallel_for( Kokkos::TeamThreadRange( team, total_src ), [&]( int t ) {
            const int node = t / nlev;
            const int lvl = t - node * nlev;
            const int dxn = node % n;
            const int dyn = node / n;
            const int xi = x0 + dxn;
            const int yi = y0 + dyn;
            const int rr = r0 + lvl;
            if ( xi <= hex_lat_ && yi <= hex_lat_ && rr <= hex_rad_ )
            {
                src_sh( node, 0, lvl ) = src_( local_subdomain_id, xi, yi, rr, 0 );
                src_sh( node, 1, lvl ) = src_( local_subdomain_id, xi, yi, rr, 1 );
                src_sh( node, 2, lvl ) = src_( local_subdomain_id, xi, yi, rr, 2 );
            }
            else
            {
                src_sh( node, 0, lvl ) = src_sh( node, 1, lvl ) = src_sh( node, 2, lvl ) = 0.0;
            }
        } );

        // All threads must see the fully staged tile before computing.
        team.team_barrier();

        // After the barrier, threads without a cell of their own can leave.
        if ( tr >= r_tile_ )
            return;
        if ( x_cell >= hex_lat_ || y_cell >= hex_lat_ )
            return;

        // --- Fused arithmetic kernel (same optimisations as EpsilonDivDivKerngen's fast paths) ---
        //
        // The divergence operator's 6x18 element matrix collapses to
        //     p_i = -qw * |det J| * shape_coarse_i * (∇·u)_phys
        // so we never materialise A. Per wedge we compute one Jacobian, one div_u at the quadrature
        // point (18 FMAs), then scatter `prefactor · shape_coarse_i` to 6 coarse nodes.
        //
        // Dirichlet: zero-out the boundary nodes' velocity contribution.
        // Freeslip:  project the boundary nodes' velocity onto the tangent plane before use
        //            (the outward unit normal on a spherical shell is just the unit-sphere
        //            lateral coord already cached in coords_sh).
        //
        // Quadrature: Felippa 1x1 single point. Weights + points fetched once.

        dense::Vec< ScalarT, 3 > quad_points[quadrature::quad_felippa_1x1_num_quad_points];
        const ScalarT qw = quad_weights[0];

        // Wedge node offset table: (ddx, ddy, ddr) per (wedge, local node 0..5). Same layout used
        // by atomically_add_local_wedge_scalar_coefficients for the coarse scatter.
        constexpr int WEDGE_NODE_OFF[2][6][3] = {
            { { 0, 0, 0 }, { 1, 0, 0 }, { 0, 1, 0 }, { 0, 0, 1 }, { 1, 0, 1 }, { 0, 1, 1 } },
            { { 1, 1, 0 }, { 0, 1, 0 }, { 1, 0, 0 }, { 1, 1, 1 }, { 0, 1, 1 }, { 1, 0, 1 } } };

        // Same lat hex-corner layout as the surface-coord gather: v0 is the "first" corner of each
        // sub-wedge, v1 and v2 the other two. Only lat nodes matter for the Jacobian block.
        const int n00 = node_id( tx, ty );
        const int n01 = node_id( tx, ty + 1 );
        const int n10 = node_id( tx + 1, ty );
        const int n11 = node_id( tx + 1, ty + 1 );

        for ( int pass = 0; pass < r_passes_; ++pass )
        {
            // Radial cell handled by this thread in this pass.
            const int lvl0 = pass * r_tile_ + tr;
            const int r_cell_abs = r0 + lvl0;
            if ( r_cell_abs >= hex_rad_ )
                break;

            const ScalarT r_1 = r_sh( lvl0 );
            const ScalarT r_2 = r_sh( lvl0 + 1 );

            const bool at_cmb = has_flag( local_subdomain_id, x_cell, y_cell, r_cell_abs, CMB );
            const bool at_surface = has_flag( local_subdomain_id, x_cell, y_cell, r_cell_abs + 1, SURFACE );

            // BC classification, done once per cell.
            bool treat_dirichlet = false;
            bool treat_freeslip = false;
            if ( at_cmb || at_surface )
            {
                const ShellBoundaryFlag sbf = at_cmb ? CMB : SURFACE;
                const BoundaryConditionFlag bcf = get_boundary_condition_flag( bcs_, sbf );
                treat_dirichlet = ( bcf == DIRICHLET );
                if constexpr ( Freeslip )
                    treat_freeslip = ( bcf == FREESLIP );
            }

            // Boundary nodes live either at the cell's bottom (ddr == 0, CMB) or top (ddr == 1, SURFACE).
            const int boundary_ddr = at_cmb ? 0 : 1;

            // Position of the fine cell within its coarse parent, radially.
            const int fine_rad_wedge = r_cell_abs % 2;

            for ( int w = 0; w < num_wedges_per_hex_cell; ++w )
            {
                // Lateral corner ordering differs between the two sub-wedges.
                const int v0 = ( w == 0 ? n00 : n11 );
                const int v1 = ( w == 0 ? n10 : n01 );
                const int v2 = ( w == 0 ? n01 : n10 );

                // --- Jacobian (same formulas as EpsilonDivDivKerngen) ---
                constexpr double ONE_THIRD = 1.0 / 3.0;
                const double half_dr = 0.5 * ( r_2 - r_1 );
                const double r_mid = 0.5 * ( r_1 + r_2 );

                const double J_0_0 = r_mid * ( -coords_sh( v0, 0 ) + coords_sh( v1, 0 ) );
                const double J_0_1 = r_mid * ( -coords_sh( v0, 0 ) + coords_sh( v2, 0 ) );
                const double J_0_2 = half_dr * ( ONE_THIRD * ( coords_sh( v0, 0 ) + coords_sh( v1, 0 ) + coords_sh( v2, 0 ) ) );
                const double J_1_0 = r_mid * ( -coords_sh( v0, 1 ) + coords_sh( v1, 1 ) );
                const double J_1_1 = r_mid * ( -coords_sh( v0, 1 ) + coords_sh( v2, 1 ) );
                const double J_1_2 = half_dr * ( ONE_THIRD * ( coords_sh( v0, 1 ) + coords_sh( v1, 1 ) + coords_sh( v2, 1 ) ) );
                const double J_2_0 = r_mid * ( -coords_sh( v0, 2 ) + coords_sh( v1, 2 ) );
                const double J_2_1 = r_mid * ( -coords_sh( v0, 2 ) + coords_sh( v2, 2 ) );
                const double J_2_2 = half_dr * ( ONE_THIRD * ( coords_sh( v0, 2 ) + coords_sh( v1, 2 ) + coords_sh( v2, 2 ) ) );

                // Determinant via cofactor expansion.
                const double J_det = J_0_0 * J_1_1 * J_2_2 - J_0_0 * J_1_2 * J_2_1
                                   - J_0_1 * J_1_0 * J_2_2 + J_0_1 * J_1_2 * J_2_0
                                   + J_0_2 * J_1_0 * J_2_1 - J_0_2 * J_1_1 * J_2_0;
                const double abs_det = Kokkos::abs( J_det );
                const double inv_det = 1.0 / J_det;

                // J^{-T} (row d dot dN_ref = physical gradient component d)
                const double i00 = inv_det * ( J_1_1 * J_2_2 - J_1_2 * J_2_1 );
                const double i01 = inv_det * ( -J_1_0 * J_2_2 + J_1_2 * J_2_0 );
                const double i02 = inv_det * ( J_1_0 * J_2_1 - J_1_1 * J_2_0 );
                const double i10 = inv_det * ( -J_0_1 * J_2_2 + J_0_2 * J_2_1 );
                const double i11 = inv_det * ( J_0_0 * J_2_2 - J_0_2 * J_2_0 );
                const double i12 = inv_det * ( -J_0_0 * J_2_1 + J_0_1 * J_2_0 );
                const double i20 = inv_det * ( J_0_1 * J_1_2 - J_0_2 * J_1_1 );
                const double i21 = inv_det * ( -J_0_0 * J_1_2 + J_0_2 * J_1_0 );
                const double i22 = inv_det * ( J_0_0 * J_1_1 - J_0_1 * J_1_0 );

                // dN_ref[j] on the reference wedge at the Felippa centroid. Matches EpsilonDivDivKerngen.
                constexpr double ONE_SIXTH = 1.0 / 6.0;
                static constexpr double dN_ref[6][3] = {
                    { -0.5, -0.5, -ONE_SIXTH },
                    { 0.5, 0.0, -ONE_SIXTH },
                    { 0.0, 0.5, -ONE_SIXTH },
                    { -0.5, -0.5, ONE_SIXTH },
                    { 0.5, 0.0, ONE_SIXTH },
                    { 0.0, 0.5, ONE_SIXTH } };

                // --- Fused gather: accumulate (∇·u)_phys at the single quadrature point ---
                double div_u = 0.0;
#pragma unroll
                for ( int j = 0; j < num_nodes_per_wedge; ++j )
                {
                    const int ddx = WEDGE_NODE_OFF[w][j][0];
                    const int ddy = WEDGE_NODE_OFF[w][j][1];
                    const int ddr = WEDGE_NODE_OFF[w][j][2];

                    const bool is_boundary_node = ( at_cmb || at_surface ) && ( ddr == boundary_ddr );

                    // Dirichlet: boundary-node velocity contributes nothing to div_u.
                    if ( treat_dirichlet && is_boundary_node )
                        continue;

                    const int nid = node_id( tx + ddx, ty + ddy );
                    const int lvl = lvl0 + ddr;

                    double u0 = src_sh( nid, 0, lvl );
                    double u1 = src_sh( nid, 1, lvl );
                    double u2 = src_sh( nid, 2, lvl );

                    // Freeslip: project boundary-node velocity onto the tangent plane.
                    // On the spherical shell the outward normal is the unit-sphere coord itself.
                    if constexpr ( Freeslip )
                    {
                        if ( treat_freeslip && is_boundary_node )
                        {
                            const double nx = coords_sh( nid, 0 );
                            const double ny = coords_sh( nid, 1 );
                            const double nz = coords_sh( nid, 2 );
                            const double un = u0 * nx + u1 * ny + u2 * nz;
                            u0 -= un * nx;
                            u1 -= un * ny;
                            u2 -= un * nz;
                        }
                    }

                    // Physical gradient of reference shape function j.
                    const double gx = dN_ref[j][0];
                    const double gy = dN_ref[j][1];
                    const double gz = dN_ref[j][2];
                    const double g0 = i00 * gx + i01 * gy + i02 * gz;
                    const double g1 = i10 * gx + i11 * gy + i12 * gz;
                    const double g2 = i20 * gx + i21 * gy + i22 * gz;

                    div_u += g0 * u0 + g1 * u1 + g2 * u2;
                }

                // --- Scatter: p_i += -qw * |det J| * shape_coarse_i * div_u ---
                const int fine_lat_wedge = fine_lateral_wedge_idx( x_cell, y_cell, w );
                const double prefactor = -qw * abs_det * div_u;

                // Coarse pressure-grid cell indices.
                const int xc = x_cell / 2;
                const int yc = y_cell / 2;
                const int rc = r_cell_abs / 2;

#pragma unroll
                for ( int i = 0; i < num_nodes_per_wedge; ++i )
                {
                    const double shape_i = shape_coarse( i, fine_rad_wedge, fine_lat_wedge, quad_points[0] );
                    const int ddx = WEDGE_NODE_OFF[w][i][0];
                    const int ddy = WEDGE_NODE_OFF[w][i][1];
                    const int ddr = WEDGE_NODE_OFF[w][i][2];
                    // Atomic: neighbouring fine cells scatter to shared coarse nodes.
                    Kokkos::atomic_add( &dst_( local_subdomain_id, xc + ddx, yc + ddy, rc + ddr ),
                                        prefactor * shape_i );
                }
            }
        }
    }
779};
780
781static_assert( linalg::OperatorLike< DivergenceKerngen< double > > );
782
783} // namespace terra::fe::wedge::operators::shell
Definition communication_plan.hpp:33
Definition divergence_kerngen.hpp:54
Kokkos::TeamPolicy<>::member_type Team
Definition divergence_kerngen.hpp:59
ScalarT ScalarType
Definition divergence_kerngen.hpp:58
void apply_impl(const SrcVectorType &src, DstVectorType &dst)
Definition divergence_kerngen.hpp:206
DivergenceKerngen(const grid::shell::DistributedDomain &domain_fine, const grid::shell::DistributedDomain &domain_coarse, const grid::Grid3DDataVec< ScalarT, 3 > &grid_fine, const grid::Grid2DDataScalar< ScalarT > &radii_fine, const grid::Grid4DDataScalar< grid::shell::ShellBoundaryFlag > &boundary_mask_fine, BoundaryConditions bcs, linalg::OperatorApplyMode operator_apply_mode=linalg::OperatorApplyMode::Replace, linalg::OperatorCommunicationMode operator_communication_mode=linalg::OperatorCommunicationMode::CommunicateAdditively)
Definition divergence_kerngen.hpp:121
const char * path_name() const
Definition divergence_kerngen.hpp:180
void force_slow_path()
Force the slow path. Useful for validation/testing the fast paths against a reference.
Definition divergence_kerngen.hpp:193
KernelPath kernel_path() const
Definition divergence_kerngen.hpp:190
void set_operator_apply_and_communication_modes(const linalg::OperatorApplyMode operator_apply_mode, const linalg::OperatorCommunicationMode operator_communication_mode)
Definition divergence_kerngen.hpp:195
size_t team_shmem_size(const int) const
Definition divergence_kerngen.hpp:263
KernelPath
Definition divergence_kerngen.hpp:62
Parallel data structure organizing the thick spherical shell metadata for distributed (MPI parallel) ...
Definition spherical_shell.hpp:2518
const std::map< SubdomainInfo, std::tuple< LocalSubdomainIdx, SubdomainNeighborhood > > & subdomains() const
Definition spherical_shell.hpp:2650
const DomainInfo & domain_info() const
Returns a const reference.
Definition spherical_shell.hpp:2647
Information about the thick spherical shell mesh.
Definition spherical_shell.hpp:780
int subdomain_num_nodes_radially() const
Equivalent to calling subdomain_num_nodes_radially( subdomain_refinement_level() )
Definition spherical_shell.hpp:861
int subdomain_num_nodes_per_side_laterally() const
Equivalent to calling subdomain_num_nodes_per_side_laterally( subdomain_refinement_level() )
Definition spherical_shell.hpp:852
Q1 scalar finite element vector on a distributed shell grid.
Definition vector_q1.hpp:21
const grid::Grid4DDataScalar< ScalarType > & grid_data() const
Get const reference to grid data.
Definition vector_q1.hpp:139
const grid::Grid4DDataVec< ScalarType, VecDim > & grid_data() const
Get const reference to grid data.
Definition vector_q1.hpp:288
Timer supporting RAII scope or manual stop.
Definition timer.hpp:342
void stop()
Stop the timer and record elapsed time.
Definition timer.hpp:364
J
Definition EpsilonDivDiv_kernel_gen.py:199
r_cell
Definition EpsilonDivDiv_kernel_gen.py:23
qw
Definition EpsilonDivDiv_kernel_gen.py:166
x_cell
Definition EpsilonDivDiv_kernel_gen.py:23
g2
Definition EpsilonDivDiv_kernel_gen.py:267
div_u
Definition EpsilonDivDiv_kernel_gen.py:359
local_subdomain_id
Definition EpsilonDivDiv_kernel_gen.py:23
g1
Definition EpsilonDivDiv_kernel_gen.py:266
J_det
Definition EpsilonDivDiv_kernel_gen.py:216
y_cell
Definition EpsilonDivDiv_kernel_gen.py:23
constexpr int idx(const int loop_idx, const int size, const grid::BoundaryPosition position, const grid::BoundaryDirection direction)
Definition buffer_copy_kernels.hpp:73
void send_recv_with_plan(const ShellBoundaryCommPlan< GridDataType > &plan, const GridDataType &data, SubdomainNeighborhoodSendRecvBuffer< typename GridDataType::value_type, grid::grid_data_vec_dim< GridDataType >() > &recv_buffers, CommunicationReduction reduction=CommunicationReduction::SUM)
Definition communication_plan.hpp:652
Definition boundary_mass.hpp:14
constexpr void quad_felippa_1x1_quad_points(dense::Vec< T, 3 >(&quad_points)[quad_felippa_1x1_num_quad_points])
Definition wedge/quadrature/quadrature.hpp:36
constexpr void quad_felippa_1x1_quad_weights(T(&quad_weights)[quad_felippa_1x1_num_quad_points])
Definition wedge/quadrature/quadrature.hpp:43
constexpr int quad_felippa_1x1_num_quad_points
Definition wedge/quadrature/quadrature.hpp:32
constexpr int num_nodes_per_wedge_surface
Definition kernel_helpers.hpp:6
void atomically_add_local_wedge_scalar_coefficients(const grid::Grid4DDataScalar< T > &global_coefficients, const int local_subdomain_id, const int x_cell, const int y_cell, const int r_cell, const dense::Vec< T, 6 >(&local_coefficients)[2])
Performs an atomic add of the two local wedge coefficient vectors of a hex cell into the global coeff...
Definition kernel_helpers.hpp:407
constexpr int fine_lateral_wedge_idx(const int x_cell_fine, const int y_cell_fine, const int wedge_idx_fine)
Returns the lateral wedge index with respect to a coarse grid wedge from the fine wedge indices.
Definition kernel_helpers.hpp:601
void wedge_surface_physical_coords(dense::Vec< T, 3 >(&wedge_surf_phy_coords)[num_wedges_per_hex_cell][num_nodes_per_wedge_surface], const grid::Grid3DDataVec< T, 3 > &lateral_grid, const int local_subdomain_id, const int x_cell, const int y_cell)
Extracts the (unit sphere) surface vertex coords of the two wedges of a hex cell.
Definition kernel_helpers.hpp:26
constexpr void reorder_local_dofs(const DoFOrdering doo_from, const DoFOrdering doo_to, dense::Vec< ScalarT, 18 > &dofs)
Definition kernel_helpers.hpp:619
constexpr T shape_coarse(const int coarse_node_idx, const int fine_radial_wedge_idx, const int fine_lateral_wedge_idx, const T xi_fine, const T eta_fine, const T zeta_fine)
Definition integrands.hpp:373
constexpr int num_wedges_per_hex_cell
Definition kernel_helpers.hpp:5
void extract_local_wedge_vector_coefficients(dense::Vec< T, 6 >(&local_coefficients)[2], const int local_subdomain_id, const int x_cell, const int y_cell, const int r_cell, const int d, const grid::Grid4DDataVec< T, VecDim > &global_coefficients)
Extracts the local vector coefficients for the two wedges of a hex cell from the global coefficient v...
Definition kernel_helpers.hpp:356
constexpr int num_nodes_per_wedge
Definition kernel_helpers.hpp:7
constexpr dense::Vec< T, 3 > grad_shape(const int node_idx, const T xi, const T eta, const T zeta)
Gradient of the full shape function:
Definition integrands.hpp:228
constexpr dense::Mat< T, 3, 3 > jac(const dense::Vec< T, 3 > &p1_phy, const dense::Vec< T, 3 > &p2_phy, const dense::Vec< T, 3 > &p3_phy, const T r_1, const T r_2, const T xi, const T eta, const T zeta)
Definition integrands.hpp:657
dense::Vec< typename CoordsShellType::value_type, 3 > coords(const int subdomain, const int x, const int y, const int r, const CoordsShellType &coords_shell, const CoordsRadiiType &coords_radii)
Definition spherical_shell.hpp:2871
BoundaryConditionMapping[2] BoundaryConditions
Definition shell/bit_masks.hpp:37
ShellBoundaryFlag
FlagLike that indicates boundary types for the thick spherical shell.
Definition shell/bit_masks.hpp:12
BoundaryConditionFlag get_boundary_condition_flag(const BoundaryConditions bcs, ShellBoundaryFlag sbf)
Retrieve the boundary condition flag that is associated with a location in the shell e....
Definition shell/bit_masks.hpp:42
BoundaryConditionFlag
FlagLike that indicates the type of boundary condition
Definition shell/bit_masks.hpp:25
Kokkos::View< ScalarType ***[VecDim], Layout > Grid3DDataVec
Definition grid_types.hpp:42
Kokkos::View< ScalarType ****, Layout > Grid4DDataScalar
Definition grid_types.hpp:27
Kokkos::View< ScalarType **, Layout > Grid2DDataScalar
Definition grid_types.hpp:21
dense::Mat< ScalarType, 3, 3 > trafo_mat_cartesian_to_normal_tangential(const dense::Vec< ScalarType, 3 > &n_input)
Constructs a robust orthonormal transformation matrix from Cartesian to (normal–tangential–tangential...
Definition local_basis_trafo_normal_tangential.hpp:36
OperatorApplyMode
Modes for applying an operator to a vector.
Definition operator.hpp:30
@ Replace
Overwrite the destination vector.
OperatorCommunicationMode
Modes for communication during operator application.
Definition operator.hpp:40
@ CommunicateAdditively
Communicate and add results.
constexpr bool has_flag(E mask_value, E flag) noexcept
Checks if a bitmask value contains a specific flag.
Definition bit_masking.hpp:43
detail::PrefixCout logroot([]() { return detail::log_prefix();})
std::ostream subclass that just logs on root and adds a timestamp for each line.