// Team member handle used by the hierarchical kernels of this operator
// (Kokkos::TeamPolicy<> runs on the default execution space).
70 using Team = Kokkos::TeamPolicy<>::member_type;
// Boundary-condition configuration; queried per radial boundary (CMB / SURFACE)
// through get_boundary_condition_flag when the kernel path is chosen.
87 BoundaryConditions bcs_;
// Number of locally owned subdomains. The constructor sets it from
// domain_fine_.subdomains().size(); it is the outermost factor of blocks_.
98 int local_subdomains_;
// Host-side choice of the kernel variant. The assignments live on lines outside
// this excerpt (original 115-118 and 122-130). Presumably they set kernel_path_,
// the value the path-label switch inspects — confirm.
//  - Serial default execution space: dedicated branch (body not visible).
//  - Otherwise has_freeslip records whether either radial boundary (CMB or
//    SURFACE) uses FREESLIP. It presumably picks between the fast free-slip and
//    the fast Dirichlet/Neumann kernels.
// NOTE(review): run_team_fast reads coarse scratch at tile-local offset +1. That
// needs lat_tile_ >= 2 and r_tile_block_ >= 2. The constructor picks 1x1x1 tiles
// for Serial and for OpenMP with fewer than 16 team members. Make sure the logic
// not shown here forces the slow path in those configurations.
112 void update_kernel_path_flag_host_only()
114 if constexpr ( std::is_same_v< Kokkos::DefaultExecutionSpace, Kokkos::Serial > )
// Boundary-condition flags of the inner (CMB) and outer (SURFACE) radial boundaries.
119 const BoundaryConditionFlag cmb_bc = get_boundary_condition_flag( bcs_, CMB );
120 const BoundaryConditionFlag surface_bc = get_boundary_condition_flag( bcs_, SURFACE );
121 const bool has_freeslip = ( cmb_bc == FREESLIP ) || ( surface_bc == FREESLIP );
// Constructor. Its opening lines are outside this excerpt; `bcs` is one of its
// parameters. It stores the fine/coarse domains, fine grid coordinates, fine
// radii, fine boundary mask and operator modes. recv_buffers_ and comm_plan_ are
// both built from domain_fine.
132 BoundaryConditions bcs,
136 : domain_fine_( domain_fine )
137 , domain_coarse_( domain_coarse )
138 , grid_fine_( grid_fine )
139 , radii_( radii_fine )
140 , boundary_mask_fine_( boundary_mask_fine )
141 , operator_apply_mode_( operator_apply_mode )
142 , operator_communication_mode_( operator_communication_mode )
143 , recv_buffers_( domain_fine )
144 , comm_plan_( domain_fine )
150 local_subdomains_ = domain_fine_.
subdomains().size();
// Per-execution-space tile shape. A team covers lat_tile_ x lat_tile_ lateral
// cells and r_tile_ radial cells at a time. Each thread repeats over r_passes_
// radial passes spaced r_tile_ apart, so one team spans
// r_tile_block_ = r_tile_ * r_passes_ radial cells.
154 if constexpr ( std::is_same_v< Kokkos::DefaultExecutionSpace, Kokkos::Serial > )
// Serial: one cell per team, one pass (team_size_ == 1).
156 lat_tile_ = 1; r_tile_ = 1; r_passes_ = 1;
158#ifdef KOKKOS_ENABLE_OPENMP
159 else if constexpr ( std::is_same_v< Kokkos::DefaultExecutionSpace, Kokkos::OpenMP > )
// OpenMP: the usable team size is capped by both the available concurrency and
// the host team-member limit.
// NOTE(review): HostThreadTeamData is in Kokkos::Impl, an internal API that may
// change between Kokkos releases.
161 const int max_team = std::min( Kokkos::OpenMP().concurrency(),
162 static_cast< int >( Kokkos::Impl::HostThreadTeamData::max_team_members ) );
// >= 64 threads: 4x4x4 tile, 4 passes  -> team_size_ 64, radial block 16.
// >= 16 threads: 4x4x1 tile, 16 passes -> team_size_ 16, radial block 16.
// otherwise    : 1x1x1 tile, 1 pass    -> team_size_ 1.
163 if ( max_team >= 64 ) { lat_tile_ = 4; r_tile_ = 4; r_passes_ = 4; }
164 else if ( max_team >= 16 ) { lat_tile_ = 4; r_tile_ = 1; r_passes_ = 16; }
165 else { lat_tile_ = 1; r_tile_ = 1; r_passes_ = 1; }
// Fallback branch. Its opening (original lines 166-169) is not visible; it is
// presumably used by GPU back ends. 4x4x8 tile, 2 passes -> team_size_ 128,
// radial block 16. The value 128 matches the LaunchBounds< 128, 5 > used for the
// fast Dirichlet/Neumann launch.
170 lat_tile_ = 4; r_tile_ = 8; r_passes_ = 2;
// Derived launch geometry. Tile counts are ceiling divisions, so edge tiles can
// be partial; the kernels test cells against hex_lat_ / hex_rad_. The league size
// blocks_ is subdomains x lateral tiles^2 x radial tiles.
172 r_tile_block_ = r_tile_ * r_passes_;
173 lat_tiles_ = ( hex_lat_ + lat_tile_ - 1 ) / lat_tile_;
174 r_tiles_ = ( hex_rad_ + r_tile_block_ - 1 ) / r_tile_block_;
175 team_size_ = lat_tile_ * lat_tile_ * r_tile_;
176 blocks_ = local_subdomains_ * lat_tiles_ * lat_tiles_ * r_tiles_;
178 update_kernel_path_flag_host_only();
// Log the chosen tile shape, launch geometry and kernel path through util::logroot.
180 util::logroot <<
"[GradientKerngen] tile=(" << lat_tile_ <<
"," << lat_tile_ <<
"," << r_tile_
181 <<
"), r_passes=" << r_passes_ <<
", team=" << team_size_ <<
", blocks=" << blocks_
182 <<
", path=" <<
path_name() << std::endl;
// Body of the kernel-path label function. It is presumably path_name(), which
// the constructor logs; its signature (original lines 185-186) is not visible.
// NOTE(review): `default:` returns "fast-dirichlet-neumann". Any kernel path
// without its own case label is reported under that name; the case labels on
// original lines 188-190 are not visible.
187 switch ( kernel_path_ )
191 default:
return "fast-dirichlet-neumann";
// Replace the operator apply / communication modes after construction.
// The enclosing function's signature is not visible in this excerpt.
203 operator_apply_mode_ = operator_apply_mode;
204 operator_communication_mode_ = operator_communication_mode;
// Kernel launch. The league has blocks_ teams of team_size_ threads; the
// geometry is fixed in the constructor. Each team gets
// team_shmem_size( team_size_ ) bytes of level-0 scratch, which run_team_fast
// uses for its coordinate, radius and coarse-source tiles.
// The launch that runs is presumably selected by kernel_path_. The control flow
// (original lines 221-241) is mostly not visible.
220 Kokkos::TeamPolicy<> policy( blocks_, team_size_ );
223 policy.set_scratch_size( 0, Kokkos::PerTeam(
team_shmem_size( team_size_ ) ) );
// Reference path: dense per-cell local matrices (process_cell_legacy).
228 Kokkos::parallel_for(
229 "gradient_apply_kernel_slow", policy, KOKKOS_CLASS_LAMBDA(
const Team& team ) {
230 this->run_team_slow( team );
// Fast path with the free-slip handling compiled in.
235 Kokkos::parallel_for(
236 "gradient_apply_kernel_fast_fs", policy, KOKKOS_CLASS_LAMBDA(
const Team& team ) {
237 this->
template run_team_fast<
true >( team );
// Fast path without the free-slip code.
// LaunchBounds< 128, 5 > allows at most 128 threads per team and hints at least
// 5 resident blocks per multiprocessor on GPU back ends. The non-Serial,
// non-OpenMP tile choice gives team_size_ = 4 * 4 * 8 = 128, so it fits.
// NOTE(review): a tile change that raises team_size_ above 128 invalidates this
// launch.
242 Kokkos::TeamPolicy< Kokkos::LaunchBounds< 128, 5 > > dn_policy( blocks_, team_size_ );
243 dn_policy.set_scratch_size( 0, Kokkos::PerTeam(
team_shmem_size( team_size_ ) ) );
244 Kokkos::parallel_for(
245 "gradient_apply_kernel_fast_dn", dn_policy, KOKKOS_CLASS_LAMBDA(
const Team& team ) {
246 this->
template run_team_fast<
false >( team );
// Per-team scratch requirement for run_team_fast. This is presumably
// team_shmem_size; its signature (original lines 261-262) is not visible.
// It counts scalars for three tiles:
//   coords_sh : nxy x 3          lateral node coordinates, nxy  = (lat_tile_ + 1)^2
//   r_sh      : nlev             radii,                    nlev = r_tile_block_ + 1
//   p_sh      : nxy_c x nlev_c   coarse source values,     n_c  = lat_tile_ / 2 + 1,
//                                                          nlev_c = r_tile_block_ / 2 + 1
// The conversion of nscalars into the returned size (original lines 271-274) is
// not visible. run_team_fast views the buffer as double.
// NOTE(review): the team-size argument passed by callers is not used in the
// visible lines.
// NOTE(review): n_c / nlev_c are only right for even tile extents. With
// lat_tile_ == 1 (or r_tile_block_ == 1), a cell still touches coarse nodes at
// tile-local offsets 0 and 1, but n_c (nlev_c) evaluates to 1.
// ( lat_tile_ + 1 ) / 2 + 1 would cover both cases. Any change must be mirrored
// in run_team_fast.
260 KOKKOS_INLINE_FUNCTION
263 const int nlev = r_tile_block_ + 1;
264 const int n = lat_tile_ + 1;
265 const int nxy = n * n;
266 const int n_c = ( lat_tile_ / 2 ) + 1;
267 const int nxy_c = n_c * n_c;
268 const int nlev_c = ( r_tile_block_ / 2 ) + 1;
270 const size_t nscalars = size_t( nxy ) * 3 + size_t( nlev ) + size_t( nxy_c ) * nlev_c;
// Maps the calling team and thread to their work.
//  - League rank -> local subdomain and tile origin (x0, y0, r0). The rank is
//    unpacked with the radial tile varying fastest, then lateral y, then lateral
//    x; the remainder is the subdomain. This matches
//    blocks_ = local_subdomains_ * lat_tiles_^2 * r_tiles_. The `tmp` reductions
//    between extractions (original lines 285/287/289) are not visible.
//  - Team rank -> in-tile offsets (tx, ty, tr). The radial offset varies fastest,
//    then tx, then ty. tr and the absolute x_cell / y_cell / r_cell are assigned
//    on lines not visible (297, 300-305); presumably x_cell = x0 + tx etc. — confirm.
// The first parameter (original line 277) is not visible; callers pass the team handle.
275 KOKKOS_INLINE_FUNCTION
276 void decode_team_indices(
278 int& local_subdomain_id,
279 int& x0,
int& y0,
int& r0,
280 int& tx,
int& ty,
int& tr,
281 int& x_cell,
int& y_cell,
int& r_cell )
const
283 int tmp = team.league_rank();
284 const int r_tile_id = tmp % r_tiles_;
286 const int lat_y_id = tmp % lat_tiles_;
288 const int lat_x_id = tmp % lat_tiles_;
290 local_subdomain_id = tmp;
// Tile origins in fine-cell units; the radial origin advances by a full radial block.
292 x0 = lat_x_id * lat_tile_;
293 y0 = lat_y_id * lat_tile_;
294 r0 = r_tile_id * r_tile_block_;
296 const int tid = team.team_rank();
298 tx = ( tid / r_tile_ ) % lat_tile_;
299 ty = tid / ( r_tile_ * lat_tile_ );
// NOTE(review): the definition introduced by this macro (original lines 307-315)
// is outside this excerpt.
306 KOKKOS_INLINE_FUNCTION
// Reference ("slow") team kernel. Each thread walks its r_passes_ radial cells,
// spaced r_tile_ apart inside the team's radial block, and assembles every
// in-range cell with process_cell_legacy. Cells with r >= hex_rad_ or
// x/y >= hex_lat_ are presumably skipped; the skip statements are on lines not
// visible.
316 KOKKOS_INLINE_FUNCTION
317 void run_team_slow(
const Team& team )
const
319 int local_subdomain_id, x0, y0, r0, tx, ty, tr, x_cell, y_cell, r_cell;
320 decode_team_indices( team, local_subdomain_id, x0, y0, r0, tx, ty, tr, x_cell, y_cell, r_cell );
325 for (
int pass = 0; pass < r_passes_; ++pass )
327 const int r_cell_pass = r0 + pass * r_tile_ + tr;
328 if ( r_cell_pass >= hex_rad_ )
330 if ( x_cell >= hex_lat_ || y_cell >= hex_lat_ )
333 process_cell_legacy( local_subdomain_id, x_cell, y_cell, r_cell_pass );
// Reference per-cell assembly. For each of the two wedges of hex cell
// (s, x_cell, y_cell, r_cell) it builds a dense 18 x 6 local matrix A[wedge].
// Row d * 6 + i is vector component d (0..2) at fine wedge node i; column j is
// coarse shape function j (shape_coarse). The entries are
//   sum_q  -w_q * ( J^{-T} grad_shape( i ) )_d * shape_coarse( j ) * |det J|.
// It then applies the boundary treatment, multiplies by the local source
// coefficients src[wedge], and scatters one 6-vector per component into dst_.
338 KOKKOS_INLINE_FUNCTION
339 void process_cell_legacy(
int s,
int x_cell,
int y_cell,
int r_cell )
const
// Radii of the cell's bottom (r_cell) and top (r_cell + 1) node layers.
344 const ScalarT r_1 = radii_( s, r_cell );
345 const ScalarT r_2 = radii_( s, r_cell + 1 );
// Parity of r_cell; passed to shape_coarse together with the lateral wedge index,
// which is set on a line not visible. Presumably it selects which radial half of
// the enclosing coarse cell this fine cell occupies.
353 const int fine_radial_wedge_index = r_cell % 2;
// Quadrature over all num_quad_points; the wedge / i / j loop headers are on
// lines not visible.
357 for (
int q = 0; q < num_quad_points; q++ )
362 const auto J =
jac( wedge_phy_surf[wedge], r_1, r_2, quad_points[q] );
363 const auto det = Kokkos::abs(
J.det() );
364 const auto J_inv_transposed =
J.inv().transposed();
368 const auto grad_i =
grad_shape( i, quad_points[q] );
372 shape_coarse( j, fine_radial_wedge_index, fine_lateral_wedge_index, quad_points[q] );
373 for (
int d = 0; d < 3; d++ )
375 A[wedge]( d * 6 + i, j ) +=
376 quad_weights[q] * ( -( ( J_inv_transposed * grad_i )( d ) * shape_j ) * det );
// Boundary flags: CMB is tested on the bottom node layer, SURFACE on the top layer.
386 bool at_cmb =
util::has_flag( boundary_mask_fine_( s, x_cell, y_cell, r_cell ), CMB );
387 bool at_surface =
util::has_flag( boundary_mask_fine_( s, x_cell, y_cell, r_cell + 1 ), SURFACE );
// Entry mask applied by Hadamard product: 1 keeps an entry, 0 removes it.
389 dense::Mat< ScalarT, 18, 6 > boundary_mask;
390 boundary_mask.fill( 1.0 );
// Set when free-slip rows are rotated; the local result is rotated back below.
393 bool freeslip_reorder =
false;
// `bcf` is derived on lines not visible (original 396-399), presumably from the
// boundary the cell touches.
395 if ( at_cmb || at_surface )
// Dirichlet: mask the boundary-layer nodes (i < 3 at the CMB, i >= 3 at the
// surface) for all three components. The zeroing statement is on a line not
// visible. Both layers are handled if both flags are set.
400 if ( bcf == DIRICHLET )
402 for (
int dimi = 0; dimi < 3; ++dimi )
405 if ( ( at_cmb && i < 3 ) || ( at_surface && i >= 3 ) )
// Free-slip proceeds in three steps:
//  1. Copy rows into A_tmp with node-major indices node * 3 + dim; the right-hand
//     side is on a line not visible.
//  2. Set A = R * A_tmp, where R is the identity except for the 3x3 blocks R_i of
//     the three boundary nodes.
//  3. Zero row node * 3 of each boundary node, presumably the normal component in
//     the rotated frame — confirm against R_i.
// NOTE(review): if a cell touches both CMB and SURFACE, R is built for the CMB
// layer (offset_in_R == 0, r_cell + 0), but the mask zeroes the SURFACE nodes
// (node_start == 3).
408 else if ( bcf == FREESLIP )
410 freeslip_reorder =
true;
413 for (
int wedge = 0; wedge < 2; ++wedge )
415 for (
int dimi = 0; dimi < 3; ++dimi )
417 A_tmp[wedge]( node_idxi * 3 + dimi, node_idxj ) =
// Lateral (x, y) offsets of the three nodes of each wedge on one node layer.
420 constexpr int layer_hex_offset_x[2][3] = { { 0, 1, 0 }, { 1, 0, 1 } };
421 constexpr int layer_hex_offset_y[2][3] = { { 0, 0, 1 }, { 1, 1, 0 } };
423 for (
int wedge = 0; wedge < 2; ++wedge )
425 for (
int i = 0; i < 18; ++i )
426 R[wedge]( i, i ) = 1.0;
428 for (
int bn = 0; bn < 3; ++bn )
432 x_cell + layer_hex_offset_x[wedge][bn],
433 y_cell + layer_hex_offset_y[wedge][bn],
434 r_cell + ( at_cmb ? 0 : 1 ),
439 const int offset_in_R = at_cmb ? 0 : 9;
440 for (
int dimi = 0; dimi < 3; ++dimi )
441 for (
int dimj = 0; dimj < 3; ++dimj )
442 R[wedge]( offset_in_R + bn * 3 + dimi, offset_in_R + bn * 3 + dimj ) = R_i( dimi, dimj );
445 A[wedge] = R[wedge] * A_tmp[wedge];
447 const int node_start = at_surface ? 3 : 0;
448 const int node_end = at_surface ? 6 : 3;
449 for (
int node_idx = node_start; node_idx < node_end; ++node_idx )
451 const int idx = node_idx * 3;
452 for (
int k = 0; k < 6; ++k )
453 boundary_mask( idx, k ) = 0.0;
// Neumann: the branch body is not visible.
457 else if ( bcf == NEUMANN )
// Apply the mask; the enclosing wedge loop is on lines not visible.
464 A[wedge].hadamard_product( boundary_mask );
// Local products with the source coefficients src[0], src[1], which are gathered
// on lines not visible.
467 dst_loc[0] = A[0] * src[0];
468 dst_loc[1] = A[1] * src[1];
// Rotate free-slip results back with R^T.
// NOTE(review): A, R and the mask use node-major rows (node * 3 + dim) here.
// The copy below is index-preserving, while the scatter slices dst_loc
// component-major (d * 6 .. d * 6 + 5). Unless a re-permutation happens on lines
// not visible (original 476, 479-483), the free-slip result is scattered in the
// wrong layout — verify.
470 if ( freeslip_reorder )
473 dst_tmp[0] = R[0].transposed() * dst_loc[0];
474 dst_tmp[1] = R[1].transposed() * dst_loc[1];
475 for (
int i = 0; i < 18; ++i )
477 dst_loc[0]( i ) = dst_tmp[0]( i );
478 dst_loc[1]( i ) = dst_tmp[1]( i );
// Scatter the 6-entry slice of each component d into dst_. The helper called on
// original line 489 is not visible.
484 for (
int d = 0; d < 3; d++ )
487 dst_d[0] = dst_loc[0].template slice< 6 >( d * 6 );
488 dst_d[1] = dst_loc[1].template slice< 6 >( d * 6 );
490 dst_, s, x_cell, y_cell, r_cell, d, dst_d );
// Fast team kernel. Freeslip == true compiles in the free-slip projection;
// Freeslip == false omits it and is the variant used for the Dirichlet/Neumann
// launch.
// Phase 1: the team cooperatively stages three tiles into level-0 scratch:
//   - coords_sh: lateral node coordinates of its (lat_tile_ + 1)^2 tile nodes,
//   - r_sh:      the r_tile_block_ + 1 radii of its radial block,
//   - p_sh:      the coarse source values of src_ covering the tile.
// Phase 2: each thread processes its r_passes_ radial cells. It evaluates the
// element integral at one quadrature point and atomically adds the three
// component contributions of every wedge node into dst_.
// NOTE(review): only quad_points[0] / quad_weights[0] are used, and the Jacobian
// is built at the lateral centroid and mid radius. process_cell_legacy sums over
// num_quad_points. The two paths agree only for a one-point centroid rule — confirm.
// NOTE(review): the scratch views and Jacobian arithmetic are hard-coded to
// double, while the reference path uses ScalarT.
498 template <
bool Freeslip >
499 KOKKOS_INLINE_FUNCTION
500 void run_team_fast(
const Team& team )
const
// The local index declarations (original lines 501-502) are not visible.
503 decode_team_indices( team, local_subdomain_id, x0, y0, r0, tx, ty, tr, x_cell, y_cell, r_cell );
// Tile extents. They must match the scratch size computed by team_shmem_size.
505 const int nlev = r_tile_block_ + 1;
506 const int n = lat_tile_ + 1;
507 const int nxy = n * n;
508 const int n_c = ( lat_tile_ / 2 ) + 1;
509 const int nxy_c = n_c * n_c;
510 const int nlev_c = ( r_tile_block_ / 2 ) + 1;
// Base of this team's level-0 scratch allocation, viewed as double. The
// declaration on original line 512 is not visible.
513 reinterpret_cast< double*
>( team.team_shmem().get_shmem(
team_shmem_size( team.team_size() ) ) );
// Unmanaged, row-major scratch views over that buffer. The base pointer is
// presumably advanced past each view on lines not visible — confirm.
515 using ScratchCoords =
516 Kokkos::View< double**, Kokkos::LayoutRight, typename Team::scratch_memory_space, Kokkos::MemoryUnmanaged >;
518 Kokkos::View< double*, Kokkos::LayoutRight, typename Team::scratch_memory_space, Kokkos::MemoryUnmanaged >;
520 Kokkos::View< double**, Kokkos::LayoutRight, typename Team::scratch_memory_space, Kokkos::MemoryUnmanaged >;
522 ScratchCoords coords_sh( shmem, nxy, 3 );
524 ScratchR r_sh( shmem, nlev );
526 ScratchP p_sh( shmem, nxy_c, nlev_c );
// Tile-local linear node ids (x fastest) for fine and coarse lateral nodes.
528 auto node_id = [&](
int nx,
int ny ) ->
int {
return nx + n * ny; };
529 auto node_id_c = [&](
int nx,
int ny ) ->
int {
return nx + n_c * ny; };
// Coarse tile origin: integer halving of the fine origin.
533 const int x0_c = x0 / 2;
534 const int y0_c = y0 / 2;
535 const int r0_c = r0 / 2;
// Stage fine lateral coordinates. Nodes beyond hex_lat_ are zero-filled.
538 Kokkos::parallel_for( Kokkos::TeamThreadRange( team, nxy ), [&](
int id ) {
539 const int dxn =
id % n;
540 const int dyn =
id / n;
541 const int xi = x0 + dxn;
542 const int yi = y0 + dyn;
543 if ( xi <= hex_lat_ && yi <= hex_lat_ )
545 coords_sh(
id, 0 ) = grid_fine_( local_subdomain_id, xi, yi, 0 );
546 coords_sh(
id, 1 ) = grid_fine_( local_subdomain_id, xi, yi, 1 );
547 coords_sh(
id, 2 ) = grid_fine_( local_subdomain_id, xi, yi, 2 );
551 coords_sh(
id, 0 ) = coords_sh(
id, 1 ) = coords_sh(
id, 2 ) = 0.0;
// Stage the radii of levels r0 .. r0 + r_tile_block_. Levels beyond hex_rad_ read 0.
556 Kokkos::parallel_for( Kokkos::TeamThreadRange( team, nlev ), [&](
int lvl ) {
557 const int rr = r0 + lvl;
558 r_sh( lvl ) = ( rr <= hex_rad_ ) ? radii_( local_subdomain_id, rr ) : 0.0;
// Stage coarse source values; the level index varies fastest in t. Entries
// outside src_'s extents (upper bounds only) are zero-filled.
562 const int total_coarse = nxy_c * nlev_c;
563 Kokkos::parallel_for( Kokkos::TeamThreadRange( team, total_coarse ), [&](
int t ) {
564 const int node_c = t / nlev_c;
565 const int lvl_c = t - node_c * nlev_c;
566 const int dxn = node_c % n_c;
567 const int dyn = node_c / n_c;
568 const int xi_c = x0_c + dxn;
569 const int yi_c = y0_c + dyn;
570 const int rr_c = r0_c + lvl_c;
573 if ( xi_c <
static_cast< int >( src_.extent( 1 ) ) &&
574 yi_c <
static_cast< int >( src_.extent( 2 ) ) &&
575 rr_c <
static_cast< int >( src_.extent( 3 ) ) )
577 p_sh( node_c, lvl_c ) = src_( local_subdomain_id, xi_c, yi_c, rr_c );
581 p_sh( node_c, lvl_c ) = 0.0;
// NOTE(review): a team barrier must separate the cooperative loads above from the
// per-thread reads below. It is presumably on original lines 582-588, which are
// not visible. Any early exit for out-of-range cells must come after that barrier.
589 if ( x_cell >= hex_lat_ || y_cell >= hex_lat_ )
// The single quadrature weight used for every cell.
596 const ScalarT
qw = quad_weights[0];
// For each wedge, the (dx, dy, dr) offsets of its six nodes relative to the
// cell's lower corner. Entries 0-2 lie on the bottom layer (dr = 0), entries 3-5
// on the top layer. The lateral order equals layer_hex_offset_x/_y in
// process_cell_legacy.
598 constexpr int WEDGE_NODE_OFF[2][6][3] = {
599 { { 0, 0, 0 }, { 1, 0, 0 }, { 0, 1, 0 }, { 0, 0, 1 }, { 1, 0, 1 }, { 0, 1, 1 } },
600 { { 1, 1, 0 }, { 0, 1, 0 }, { 1, 0, 0 }, { 1, 1, 1 }, { 0, 1, 1 }, { 1, 0, 1 } } };
// Tile-local ids of the four lateral corners of this thread's cell.
602 const int n00 = node_id( tx, ty );
603 const int n01 = node_id( tx, ty + 1 );
604 const int n10 = node_id( tx + 1, ty );
605 const int n11 = node_id( tx + 1, ty + 1 );
// Radial passes. lvl0 is the cell's bottom level within the team's radial block.
607 for (
int pass = 0; pass < r_passes_; ++pass )
609 const int lvl0 = pass * r_tile_ + tr;
610 const int r_cell_abs = r0 + lvl0;
611 if ( r_cell_abs >= hex_rad_ )
614 const ScalarT r_1 = r_sh( lvl0 );
615 const ScalarT r_2 = r_sh( lvl0 + 1 );
// Boundary flags: CMB on the bottom node layer, SURFACE on the top layer.
617 const bool at_cmb = has_flag( local_subdomain_id, x_cell, y_cell, r_cell_abs, CMB );
618 const bool at_surface = has_flag( local_subdomain_id, x_cell, y_cell, r_cell_abs + 1, SURFACE );
// treat_dirichlet is set on lines not visible (623-626). treat_freeslip can only
// become true in the Freeslip instantiation.
620 bool treat_dirichlet =
false;
621 bool treat_freeslip =
false;
622 if ( at_cmb || at_surface )
627 if constexpr ( Freeslip )
628 treat_freeslip = ( bcf == FREESLIP );
// Radial offset of the boundary node layer.
// NOTE(review): when a cell touches both CMB and SURFACE, only the CMB layer is
// treated here, including for Dirichlet. process_cell_legacy masks both layers
// for Dirichlet.
630 const int boundary_ddr = at_cmb ? 0 : 1;
// Parity of the fine radial cell, passed to shape_coarse.
632 const int fine_rad_wedge = r_cell_abs % 2;
// Coarse cell position inside the coarse tile. It is exact only when x0, y0 and
// r0 are even, i.e. for even lat_tile_ and r_tile_block_.
635 const int cxc_in_tile = (
x_cell - x0 ) / 2;
636 const int cyc_in_tile = (
y_cell - y0 ) / 2;
637 const int crc_in_tile = lvl0 / 2;
// Lateral vertices of wedge w; the wedge loop header is not visible.
// w == 0 -> (0,0), (1,0), (0,1); w == 1 -> (1,1), (0,1), (1,0).
641 const int v0 = ( w == 0 ? n00 : n11 );
642 const int v1 = ( w == 0 ? n10 : n01 );
643 const int v2 = ( w == 0 ? n01 : n10 );
// Jacobian at the lateral centroid and mid radius. The first two columns are the
// triangle edge vectors v0->v1 and v0->v2 scaled by r_mid. The third column is
// the vertex centroid scaled by half the layer thickness.
645 constexpr double ONE_THIRD = 1.0 / 3.0;
646 const double half_dr = 0.5 * ( r_2 - r_1 );
647 const double r_mid = 0.5 * ( r_1 + r_2 );
649 const double J_0_0 = r_mid * ( -coords_sh( v0, 0 ) + coords_sh( v1, 0 ) );
650 const double J_0_1 = r_mid * ( -coords_sh( v0, 0 ) + coords_sh( v2, 0 ) );
651 const double J_0_2 = half_dr * ( ONE_THIRD * ( coords_sh( v0, 0 ) + coords_sh( v1, 0 ) + coords_sh( v2, 0 ) ) );
652 const double J_1_0 = r_mid * ( -coords_sh( v0, 1 ) + coords_sh( v1, 1 ) );
653 const double J_1_1 = r_mid * ( -coords_sh( v0, 1 ) + coords_sh( v2, 1 ) );
654 const double J_1_2 = half_dr * ( ONE_THIRD * ( coords_sh( v0, 1 ) + coords_sh( v1, 1 ) + coords_sh( v2, 1 ) ) );
655 const double J_2_0 = r_mid * ( -coords_sh( v0, 2 ) + coords_sh( v1, 2 ) );
656 const double J_2_1 = r_mid * ( -coords_sh( v0, 2 ) + coords_sh( v2, 2 ) );
657 const double J_2_2 = half_dr * ( ONE_THIRD * ( coords_sh( v0, 2 ) + coords_sh( v1, 2 ) + coords_sh( v2, 2 ) ) );
// Determinant by cofactor expansion.
// NOTE(review): there is no guard against J_det == 0 (degenerate element).
659 const double J_det = J_0_0 * J_1_1 * J_2_2 - J_0_0 * J_1_2 * J_2_1
660 - J_0_1 * J_1_0 * J_2_2 + J_0_1 * J_1_2 * J_2_0
661 + J_0_2 * J_1_0 * J_2_1 - J_0_2 * J_1_1 * J_2_0;
662 const double abs_det = Kokkos::abs( J_det );
663 const double inv_det = 1.0 /
J_det;
// i_ab = (J^{-1})_{ba}, cofactor / det. These are the entries of J^{-T}, so
// g = J^{-T} * dN_ref below corresponds to J_inv_transposed * grad_shape in
// process_cell_legacy.
665 const double i00 = inv_det * ( J_1_1 * J_2_2 - J_1_2 * J_2_1 );
666 const double i01 = inv_det * ( -J_1_0 * J_2_2 + J_1_2 * J_2_0 );
667 const double i02 = inv_det * ( J_1_0 * J_2_1 - J_1_1 * J_2_0 );
668 const double i10 = inv_det * ( -J_0_1 * J_2_2 + J_0_2 * J_2_1 );
669 const double i11 = inv_det * ( J_0_0 * J_2_2 - J_0_2 * J_2_0 );
670 const double i12 = inv_det * ( -J_0_0 * J_2_1 + J_0_1 * J_2_0 );
671 const double i20 = inv_det * ( J_0_1 * J_1_2 - J_0_2 * J_1_1 );
672 const double i21 = inv_det * ( -J_0_0 * J_1_2 + J_0_2 * J_1_0 );
673 const double i22 = inv_det * ( J_0_0 * J_1_1 - J_0_1 * J_1_0 );
// Reference gradients of the six wedge shape functions. Rows are nodes; columns
// are d/dxi, d/deta, d/dzeta. The values equal the gradients of a
// linear-triangle x linear-radial basis at (1/3, 1/3, 0), presumably the same
// basis as grad_shape — confirm.
675 constexpr double ONE_SIXTH = 1.0 / 6.0;
676 static constexpr double dN_ref[6][3] = {
677 { -0.5, -0.5, -ONE_SIXTH },
678 { 0.5, 0.0, -ONE_SIXTH },
679 { 0.0, 0.5, -ONE_SIXTH },
680 { -0.5, -0.5, ONE_SIXTH },
681 { 0.5, 0.0, ONE_SIXTH },
682 { 0.0, 0.5, ONE_SIXTH } };
// Interpolate the coarse source at the quadrature point as the sum over coarse
// nodes j of p_sh * shape_coarse. The loop header and accumulation are on lines
// not visible.
689 double p_interp = 0.0;
693 const int cddx = WEDGE_NODE_OFF[w][j][0];
694 const int cddy = WEDGE_NODE_OFF[w][j][1];
695 const int cddr = WEDGE_NODE_OFF[w][j][2];
696 const int nidc = node_id_c( cxc_in_tile + cddx, cyc_in_tile + cddy );
697 const int lvlc = crc_in_tile + cddr;
698 const double pj = p_sh( nidc, lvlc );
699 const double sj =
shape_coarse( j, fine_rad_wedge, fine_lat_wedge, quad_points[0] );
// Factor shared by every node contribution of this wedge.
703 const double prefactor = -
qw * abs_det * p_interp;
// Absolute fine-cell indices used for the scatter.
706 const int xc_base =
x_cell;
707 const int yc_base =
y_cell;
708 const int rc_base = r_cell_abs;
// Per fine wedge node i; the loop header is not visible.
713 const int ddx = WEDGE_NODE_OFF[w][i][0];
714 const int ddy = WEDGE_NODE_OFF[w][i][1];
715 const int ddr = WEDGE_NODE_OFF[w][i][2];
717 const bool is_boundary_node = ( at_cmb || at_surface ) && ( ddr == boundary_ddr );
// Dirichlet nodes get no contribution; the skip statement is on a line not visible.
720 if ( treat_dirichlet && is_boundary_node )
// g = J^{-T} * grad N_i at the quadrature point.
724 const double gx = dN_ref[i][0];
725 const double gy = dN_ref[i][1];
726 const double gz = dN_ref[i][2];
727 const double g0 = i00 * gx + i01 * gy + i02 * gz;
728 const double g1 = i10 * gx + i11 * gy + i12 * gz;
729 const double g2 = i20 * gx + i21 * gy + i22 * gz;
731 double c0 = prefactor * g0;
732 double c1 = prefactor *
g1;
733 double c2 = prefactor *
g2;
// Free-slip: the node's coordinate vector is used as the normal. This presumably
// assumes the fine grid stores unit-sphere points — confirm. cn is the normal
// component of the contribution; its removal is on lines not visible.
736 if constexpr ( Freeslip )
738 if ( treat_freeslip && is_boundary_node )
740 const int nid = node_id( tx + ddx, ty + ddy );
741 const double nx = coords_sh( nid, 0 );
742 const double ny = coords_sh( nid, 1 );
743 const double nz = coords_sh( nid, 2 );
744 const double cn = c0 * nx + c1 * ny + c2 * nz;
// Atomic scatter, because neighbouring cells, wedges and teams write to shared nodes.
751 Kokkos::atomic_add( &dst_( local_subdomain_id, xc_base + ddx, yc_base + ddy, rc_base + ddr, 0 ), c0 );
752 Kokkos::atomic_add( &dst_( local_subdomain_id, xc_base + ddx, yc_base + ddy, rc_base + ddr, 1 ), c1 );
753 Kokkos::atomic_add( &dst_( local_subdomain_id, xc_base + ddx, yc_base + ddy, rc_base + ddr, 2 ), c2 );