345 int tmp = team.league_rank();
347 const int r_tile_id = tmp % r_tiles_;
350 const int lat_y_id = tmp % lat_tiles_;
353 const int lat_x_id = tmp % lat_tiles_;
356 const int local_subdomain_id = tmp;
358 const int x0 = lat_x_id * lat_tile_;
359 const int y0 = lat_y_id * lat_tile_;
360 const int r0 = r_tile_id * r_tile_;
363 const int tid = team.team_rank();
364 const int tx = tid % lat_tile_;
365 const int ty = (tid / lat_tile_) % lat_tile_;
366 const int tr = tid / (lat_tile_ * lat_tile_);
368 if (tr >= r_tile_)
return;
370 const int x_cell = x0 + tx;
371 const int y_cell = y0 + ty;
372 const int r_cell = r0 + tr;
377 const int nlev = r_tile_ + 1;
378 const int nxy = (lat_tile_ + 1) * (lat_tile_ + 1);
380 double* shmem =
reinterpret_cast<double*
>(
383 using ScratchCoords =
384 Kokkos::View<double**, Kokkos::LayoutRight, typename Team::scratch_memory_space, Kokkos::MemoryUnmanaged>;
386 Kokkos::View<double***, Kokkos::LayoutRight, typename Team::scratch_memory_space, Kokkos::MemoryUnmanaged>;
388 Kokkos::View<double**, Kokkos::LayoutRight, typename Team::scratch_memory_space, Kokkos::MemoryUnmanaged>;
390 ScratchCoords coords_sh(shmem, nxy, 3); shmem += nxy * 3;
391 ScratchSrc src_sh (shmem, nxy, 3, nlev); shmem += nxy * 3 * nlev;
392 ScratchK k_sh (shmem, nxy, nlev); shmem += nxy * nlev;
395 Kokkos::View<double*, Kokkos::LayoutRight, typename Team::scratch_memory_space, Kokkos::MemoryUnmanaged>(
// Flattens a tile-local node offset pair (nx, ny) into a single index,
// computed as nx + (lat_tile_ + 1) * ny, i.e. with a row stride of
// lat_tile_ + 1. The result is used as the node index into the
// scratch views coords_sh, k_sh and src_sh.
398 auto node_id = [&](
int nx,
int ny) ->
int {
return nx + (lat_tile_ + 1) * ny; };
402 Kokkos::parallel_for(Kokkos::TeamThreadRange(team, nxy), [&](
int n) {
403 const int dxn = n % (lat_tile_ + 1) ;
404 const int dyn = n / (lat_tile_ + 1) ;
405 const int xi = x0 + dxn;
406 const int yi = y0 + dyn;
408 if (xi <= hex_lat_ && yi <= hex_lat_) {
409 coords_sh(n,0) = grid_(local_subdomain_id, xi, yi, 0);
410 coords_sh(n,1) = grid_(local_subdomain_id, xi, yi, 1);
411 coords_sh(n,2) = grid_(local_subdomain_id, xi, yi, 2);
413 coords_sh(n,0) = coords_sh(n,1) = coords_sh(n,2) = 0.0;
418 Kokkos::parallel_for(Kokkos::TeamThreadRange(team, nlev), [&](
int lvl) {
419 const int rr = r0 + lvl;
420 r_sh(lvl) = (rr <= hex_rad_) ? radii_(local_subdomain_id, rr) : 0.0;
424 const int total_pairs = nxy * nlev;
425 Kokkos::parallel_for(Kokkos::TeamThreadRange(team, total_pairs), [&](
int t) {
426 const int lvl = t / nxy;
427 const int node = t - lvl * nxy;
429 const int dxn = node % (lat_tile_ + 1) ;
430 const int dyn = node / (lat_tile_ + 1) ;
432 const int xi = x0 + dxn;
433 const int yi = y0 + dyn;
434 const int rr = r0 + lvl;
436 if (xi <= hex_lat_ && yi <= hex_lat_ && rr <= hex_rad_) {
437 k_sh(node, lvl) = k_(local_subdomain_id, xi, yi, rr);
438 src_sh(node, 0, lvl) = src_(local_subdomain_id, xi, yi, rr, 0);
439 src_sh(node, 1, lvl) = src_(local_subdomain_id, xi, yi, rr, 1);
440 src_sh(node, 2, lvl) = src_(local_subdomain_id, xi, yi, rr, 2);
442 k_sh(node, lvl) = 0.0;
443 src_sh(node,0,lvl) = src_sh(node,1,lvl) = src_sh(node,2,lvl) = 0.0;
452 if (x_cell < hex_lat_ && y_cell < hex_lat_ && r_cell < hex_rad_) {
455 const double r_0 = r_sh( lvl0 );
456 const double r_1 = r_sh( lvl0 + 1 );
458 const bool at_cmb =
has_flag( local_subdomain_id, x_cell, y_cell, r_cell, CMB );
459 const bool at_surface =
has_flag( local_subdomain_id, x_cell, y_cell, r_cell + 1, SURFACE );
461 const bool at_boundary = at_cmb || at_surface;
462 bool treat_boundary =
false;
465 const ShellBoundaryFlag sbf = at_cmb ? CMB : SURFACE;
466 treat_boundary = ( get_boundary_condition_flag( bcs_, sbf ) == DIRICHLET );
469 const int cmb_shift = ( ( at_boundary && treat_boundary && ( !diagonal_ ) && at_cmb ) ? 3 : 0 );
470 const int surface_shift = ( ( at_boundary && treat_boundary && ( !diagonal_ ) && at_surface ) ? 3 : 0 );
473 static constexpr int WEDGE_NODE_OFF[2][6][3] = {
474 { { 0, 0, 0 }, { 1, 0, 0 }, { 0, 1, 0 }, { 0, 0, 1 }, { 1, 0, 1 }, { 0, 1, 1 } },
475 { { 1, 1, 0 }, { 0, 1, 0 }, { 1, 0, 0 }, { 1, 1, 1 }, { 0, 1, 1 }, { 1, 0, 1 } } };
477 static constexpr int WEDGE_TO_UNIQUE[2][6] = {
478 { 0, 1, 2, 3, 4, 5 },
482 constexpr double ONE_THIRD = 1.0 / 3.0;
483 constexpr double ONE_SIXTH = 1.0 / 6.0;
484 constexpr double NEG_TWO_THIRDS = -0.66666666666666663;
486 static constexpr double dN_ref[6][3] = {
487 { -0.5, -0.5, -ONE_SIXTH },
488 { 0.5, 0.0, -ONE_SIXTH },
489 { 0.0, 0.5, -ONE_SIXTH },
490 { -0.5, -0.5, ONE_SIXTH },
491 { 0.5, 0.0, ONE_SIXTH },
492 { 0.0, 0.5, ONE_SIXTH } };
497 const int n00 = node_id( tx, ty );
498 const int n01 = node_id( tx, ty + 1 );
499 const int n10 = node_id( tx + 1, ty );
500 const int n11 = node_id( tx + 1, ty + 1 );
506 ws[0][0][0] = coords_sh( n00, 0 ); ws[0][0][1] = coords_sh( n00, 1 ); ws[0][0][2] = coords_sh( n00, 2 );
507 ws[0][1][0] = coords_sh( n10, 0 ); ws[0][1][1] = coords_sh( n10, 1 ); ws[0][1][2] = coords_sh( n10, 2 );
508 ws[0][2][0] = coords_sh( n01, 0 ); ws[0][2][1] = coords_sh( n01, 1 ); ws[0][2][2] = coords_sh( n01, 2 );
511 ws[1][0][0] = coords_sh( n11, 0 ); ws[1][0][1] = coords_sh( n11, 1 ); ws[1][0][2] = coords_sh( n11, 2 );
512 ws[1][1][0] = coords_sh( n01, 0 ); ws[1][1][1] = coords_sh( n01, 1 ); ws[1][1][2] = coords_sh( n01, 2 );
513 ws[1][2][0] = coords_sh( n10, 0 ); ws[1][2][1] = coords_sh( n10, 1 ); ws[1][2][2] = coords_sh( n10, 2 );
516 double dst8[3][8] = { 0.0 };
518 for (
int w = 0; w < 2; ++w )
524 for (
int node = 0; node < 6; ++node )
526 const int ddx = WEDGE_NODE_OFF[w][node][0];
527 const int ddy = WEDGE_NODE_OFF[w][node][1];
528 const int ddr = WEDGE_NODE_OFF[w][node][2];
530 const int nid = node_id( tx + ddx, ty + ddy );
531 const int lvl = lvl0 + ddr;
533 k_sum += k_sh( nid, lvl );
535 const double k_eval = ONE_SIXTH * k_sum;
539 double i00, i01, i02;
540 double i10, i11, i12;
541 double i20, i21, i22;
544 const double half_dr = 0.5 * ( r_1 - r_0 );
545 const double r_mid = 0.5 * ( r_0 + r_1 );
547 const double J_0_0 = r_mid * ( -ws[w][0][0] + ws[w][1][0] );
548 const double J_0_1 = r_mid * ( -ws[w][0][0] + ws[w][2][0] );
549 const double J_0_2 = half_dr * ( ONE_THIRD * ( ws[w][0][0] + ws[w][1][0] + ws[w][2][0] ) );
551 const double J_1_0 = r_mid * ( -ws[w][0][1] + ws[w][1][1] );
552 const double J_1_1 = r_mid * ( -ws[w][0][1] + ws[w][2][1] );
553 const double J_1_2 = half_dr * ( ONE_THIRD * ( ws[w][0][1] + ws[w][1][1] + ws[w][2][1] ) );
555 const double J_2_0 = r_mid * ( -ws[w][0][2] + ws[w][1][2] );
556 const double J_2_1 = r_mid * ( -ws[w][0][2] + ws[w][2][2] );
557 const double J_2_2 = half_dr * ( ONE_THIRD * ( ws[w][0][2] + ws[w][1][2] + ws[w][2][2] ) );
559 const double J_det = J_0_0 * J_1_1 * J_2_2 - J_0_0 * J_1_2 * J_2_1 -
560 J_0_1 * J_1_0 * J_2_2 + J_0_1 * J_1_2 * J_2_0 +
561 J_0_2 * J_1_0 * J_2_1 - J_0_2 * J_1_1 * J_2_0;
563 const double invJ = 1.0 / J_det;
565 i00 = invJ * ( J_1_1 * J_2_2 - J_1_2 * J_2_1 );
566 i01 = invJ * ( -J_1_0 * J_2_2 + J_1_2 * J_2_0 );
567 i02 = invJ * ( J_1_0 * J_2_1 - J_1_1 * J_2_0 );
569 i10 = invJ * ( -J_0_1 * J_2_2 + J_0_2 * J_2_1 );
570 i11 = invJ * ( J_0_0 * J_2_2 - J_0_2 * J_2_0 );
571 i12 = invJ * ( -J_0_0 * J_2_1 + J_0_1 * J_2_0 );
573 i20 = invJ * ( J_0_1 * J_1_2 - J_0_2 * J_1_1 );
574 i21 = invJ * ( -J_0_0 * J_1_2 + J_0_2 * J_1_0 );
575 i22 = invJ * ( J_0_0 * J_1_1 - J_0_1 * J_1_0 );
577 wJ = Kokkos::abs( J_det );
580 const double kwJ = k_eval * wJ;
583 double gu10 = 0.0, gu11 = 0.0;
584 double gu20 = 0.0, gu21 = 0.0, gu22 = 0.0;
589 for (
int dimj = 0; dimj < 3; ++dimj )
592 for (
int node_idx = cmb_shift; node_idx < 6 - surface_shift; ++node_idx )
594 const double gx = dN_ref[node_idx][0];
595 const double gy = dN_ref[node_idx][1];
596 const double gz = dN_ref[node_idx][2];
598 const double g0 = i00 * gx + i01 * gy + i02 * gz;
599 const double g1 = i10 * gx + i11 * gy + i12 * gz;
600 const double g2 = i20 * gx + i21 * gy + i22 * gz;
602 double E00, E11, E22, sym01, sym02, sym12, gdd;
603 column_grad_to_sym( dimj, g0, g1, g2, E00, E11, E22, sym01, sym02, sym12, gdd );
605 const int ddx = WEDGE_NODE_OFF[w][node_idx][0];
606 const int ddy = WEDGE_NODE_OFF[w][node_idx][1];
607 const int ddr = WEDGE_NODE_OFF[w][node_idx][2];
609 const int nid = node_id( tx + ddx, ty + ddy );
610 const int lvl = lvl0 + ddr;
612 const double s = src_sh( nid, dimj, lvl );
625 for (
int dimi = 0; dimi < 3; ++dimi )
628 for (
int node_idx = cmb_shift; node_idx < 6 - surface_shift; ++node_idx )
630 const double gx = dN_ref[node_idx][0];
631 const double gy = dN_ref[node_idx][1];
632 const double gz = dN_ref[node_idx][2];
634 const double g0 = i00 * gx + i01 * gy + i02 * gz;
635 const double g1 = i10 * gx + i11 * gy + i12 * gz;
636 const double g2 = i20 * gx + i21 * gy + i22 * gz;
638 double E00, E11, E22, sym01, sym02, sym12, gdd;
639 column_grad_to_sym( dimi, g0, g1, g2, E00, E11, E22, sym01, sym02, sym12, gdd );
641 const double pairing0 = 2.0 * sym01;
642 const double pairing1 = 2.0 * sym02;
643 const double pairing2 = 2.0 * sym12;
645 const int u = WEDGE_TO_UNIQUE[w][node_idx];
647 dst8[dimi][u] += kwJ * ( NEG_TWO_THIRDS * div_u * gdd + 2.0 * pairing0 * gu10 +
648 2.0 * pairing1 * gu20 + 2.0 * pairing2 * gu21 + 2.0 * E00 * gu00 +
649 2.0 * E11 * gu11 + 2.0 * E22 * gu22 );
654 if ( diagonal_ || ( treat_boundary && at_boundary ) )
656 for (
int dim_diagBC = 0; dim_diagBC < 3; ++dim_diagBC )
659 for (
int node_idx = surface_shift; node_idx < 6 - cmb_shift; ++node_idx )
661 const double gx = dN_ref[node_idx][0];
662 const double gy = dN_ref[node_idx][1];
663 const double gz = dN_ref[node_idx][2];
665 const double g0 = i00 * gx + i01 * gy + i02 * gz;
666 const double g1 = i10 * gx + i11 * gy + i12 * gz;
667 const double g2 = i20 * gx + i21 * gy + i22 * gz;
669 double E00, E11, E22, sym01, sym02, sym12, gdd;
670 column_grad_to_sym( dim_diagBC, g0, g1, g2, E00, E11, E22, sym01, sym02, sym12, gdd );
672 const int ddx = WEDGE_NODE_OFF[w][node_idx][0];
673 const int ddy = WEDGE_NODE_OFF[w][node_idx][1];
674 const int ddr = WEDGE_NODE_OFF[w][node_idx][2];
676 const int nid = node_id( tx + ddx, ty + ddy );
677 const int lvl = lvl0 + ddr;
679 const double s = src_sh( nid, dim_diagBC, lvl );
681 const double pairing0 = 4.0 * s;
682 const double pairing1 = 2.0 * s;
684 const int u = WEDGE_TO_UNIQUE[w][node_idx];
686 dst8[dim_diagBC][u] +=
687 kwJ * ( pairing0 * ( sym01 * sym01 ) + pairing0 * ( sym02 * sym02 ) +
688 pairing0 * ( sym12 * sym12 ) + pairing1 * ( E00 * E00 ) + pairing1 * ( E11 * E11 ) +
689 pairing1 * ( E22 * E22 ) + NEG_TWO_THIRDS * ( gdd * gdd ) * s );
696 for (
int dim_add = 0; dim_add < 3; ++dim_add )
698 Kokkos::atomic_add( &dst_( local_subdomain_id, x_cell, y_cell, r_cell, dim_add ), dst8[dim_add][0] );
699 Kokkos::atomic_add( &dst_( local_subdomain_id, x_cell + 1, y_cell, r_cell, dim_add ), dst8[dim_add][1] );
700 Kokkos::atomic_add( &dst_( local_subdomain_id, x_cell, y_cell + 1, r_cell, dim_add ), dst8[dim_add][2] );
701 Kokkos::atomic_add( &dst_( local_subdomain_id, x_cell, y_cell, r_cell + 1, dim_add ), dst8[dim_add][3] );
702 Kokkos::atomic_add( &dst_( local_subdomain_id, x_cell + 1, y_cell, r_cell + 1, dim_add ), dst8[dim_add][4] );
703 Kokkos::atomic_add( &dst_( local_subdomain_id, x_cell, y_cell + 1, r_cell + 1, dim_add ), dst8[dim_add][5] );
704 Kokkos::atomic_add( &dst_( local_subdomain_id, x_cell + 1, y_cell + 1, r_cell, dim_add ), dst8[dim_add][6] );
705 Kokkos::atomic_add( &dst_( local_subdomain_id, x_cell + 1, y_cell + 1, r_cell + 1, dim_add ), dst8[dim_add][7] );