epsilon_divdiv_kerngen_v07_split_paths.hpp
1#pragma once
2
3#include "../../../quadrature/quadrature.hpp"
6#include "dense/vec.hpp"
10#include "impl/Kokkos_Profiling.hpp"
11#include "linalg/operator.hpp"
14#include "linalg/vector.hpp"
15#include "linalg/vector_q1.hpp"
16#include "util/timer.hpp"
17
19
20namespace terra::fe::wedge::operators::shell::epsdivdiv_history
21{
30
31/**
32 * @brief Matrix-free / matrix-based epsilon-div-div operator on wedge elements in a spherical shell.
33 *
34 * This functor supports two execution modes:
35 *
36 * 1) FAST PATH (shared-memory fused local matvec)
37 * - Used only when both shell boundaries are DIRICHLET or NEUMANN
38 * - Used only when local matrix storage is disabled (mode Off)
39 * - This is the high-throughput path that loads a tile slab into team scratch memory
40 *
41 * 2) SLOW PATH (local matrix path)
42 * - Used when FREESLIP is present on either boundary, because the local basis needs
43 * a per-boundary-node rotation (normal/tangential transform)
44 * - Used whenever local matrices are stored (full/selective)
45 *
46 * The path decision is computed on the host and cached in `use_slow_path_`, so the kernel
47 * launch itself is specialized (we do not branch per thread/team inside the hot path).
48 */
49template < typename ScalarT, int VecDim = 3 >
50class EpsilonDivDivKerngenV07SplitPaths
51{
52 public:
55 using ScalarType = ScalarT;
56 static constexpr int LocalMatrixDim = 18;
58 using Grid4DDataLocalMatrices = terra::grid::Grid4DDataMatrices< ScalarType, LocalMatrixDim, LocalMatrixDim, 2 >;
60 using Team = Kokkos::TeamPolicy<>::member_type;
61
62 private:
63 // Optional storage for element-local matrices (used by GCA/coarsening or explicit local mat path)
64 LocalMatrixStorage local_matrix_storage_;
65
66 // Domain and geometry / coefficient data
68 grid::Grid3DDataVec< ScalarT, 3 > grid_; // lateral shell geometry (unit sphere coords)
69 grid::Grid2DDataScalar< ScalarT > radii_; // radial coordinates per local subdomain
70 grid::Grid4DDataScalar< ScalarType > k_; // scalar coefficient field
71 grid::Grid4DDataScalar< grid::shell::ShellBoundaryFlag > mask_; // boundary flags per cell/node
72 BoundaryConditions bcs_; // CMB/SURFACE BC types
73
74 bool diagonal_; // if true, apply only diagonal of local operator (or diagonalized fast approx)
75
76 linalg::OperatorApplyMode operator_apply_mode_;
77 linalg::OperatorCommunicationMode operator_communication_mode_;
78 linalg::OperatorStoredMatrixMode operator_stored_matrix_mode_;
79
82
83 // Views captured by device kernels during apply_impl
84 grid::Grid4DDataVec< ScalarType, VecDim > dst_;
85 grid::Grid4DDataVec< ScalarType, VecDim > src_;
86
87 // Quadrature data (Felippa 1x1 on wedge)
88 const int num_quad_points = quadrature::quad_felippa_1x1_num_quad_points;
89 dense::Vec< ScalarType, 3 > quad_points[quadrature::quad_felippa_1x1_num_quad_points];
90 ScalarType quad_weights[quadrature::quad_felippa_1x1_num_quad_points];
91
92 // Domain extents (cells) for one local subdomain
93 int local_subdomains_;
94 int hex_lat_;
95 int hex_rad_;
96 int lat_refinement_level_;
97
98 // 3D tile decomposition for TeamPolicy
99 // A team handles a slab: lat_tile_ x lat_tile_ x r_tile_ cells
100 int lat_tile_;
101 int r_tile_;
102 int lat_tiles_;
103 int r_tiles_;
104 int team_size_;
105 int blocks_;
106
107 ScalarT r_max_;
108 ScalarT r_min_;
109
110 /**
111 * @brief Cached host-side dispatch flag.
112 *
113 * true -> launch slow kernel
114 * false -> launch fast kernel
115 *
116 * Recomputed whenever boundary conditions or stored-matrix mode changes.
117 */
118 bool use_slow_path_ = false;
119
120 private:
121 /**
122 * @brief Recompute whether the operator must use the slow path.
123 *
124 * Slow path is required when:
125 * - local matrices are stored (full/selective), OR
126 * - either shell boundary uses FREESLIP (needs local basis rotation)
127 *
128 * This function is intended to be called on the host only (constructor/setters).
129 */
130 void update_kernel_path_flag_host_only()
131 {
132 const BoundaryConditionFlag cmb_bc = get_boundary_condition_flag( bcs_, CMB );
133 const BoundaryConditionFlag surface_bc = get_boundary_condition_flag( bcs_, SURFACE );
134
135 const bool has_freeslip_bc = ( cmb_bc == FREESLIP ) || ( surface_bc == FREESLIP );
136 const bool has_stored_matrices = ( operator_stored_matrix_mode_ != linalg::OperatorStoredMatrixMode::Off );
137
138 use_slow_path_ = has_freeslip_bc || has_stored_matrices;
139 }
140
141 public:
142 EpsilonDivDivKerngenV07SplitPaths(
143 const grid::shell::DistributedDomain& domain,
144 const grid::Grid3DDataVec< ScalarT, 3 >& grid,
145 const grid::Grid2DDataScalar< ScalarT >& radii,
146 const grid::Grid4DDataScalar< grid::shell::ShellBoundaryFlag >& mask,
147 const grid::Grid4DDataScalar< ScalarT >& k,
148 BoundaryConditions bcs,
149 bool diagonal,
150 linalg::OperatorApplyMode operator_apply_mode = linalg::OperatorApplyMode::Replace,
151 linalg::OperatorCommunicationMode operator_communication_mode =
152 linalg::OperatorCommunicationMode::CommunicateAdditively,
153 linalg::OperatorStoredMatrixMode operator_stored_matrix_mode = linalg::OperatorStoredMatrixMode::Off )
154 : domain_( domain )
155 , grid_( grid )
156 , radii_( radii )
157 , k_( k )
158 , mask_( mask )
159 , diagonal_( diagonal )
160 , operator_apply_mode_( operator_apply_mode )
161 , operator_communication_mode_( operator_communication_mode )
162 , operator_stored_matrix_mode_( operator_stored_matrix_mode )
163 , recv_buffers_( domain )
164 , comm_plan_( domain )
165 {
166 bcs_[0] = bcs[0];
167 bcs_[1] = bcs[1];
168
169 quadrature::quad_felippa_1x1_quad_points( quad_points );
170 quadrature::quad_felippa_1x1_quad_weights( quad_weights );
171
172 const grid::shell::DomainInfo& domain_info = domain_.domain_info();
173 local_subdomains_ = domain_.subdomains().size();
174 hex_lat_ = domain_info.subdomain_num_nodes_per_side_laterally() - 1;
175 hex_rad_ = domain_info.subdomain_num_nodes_radially() - 1;
176 lat_refinement_level_ = domain_info.diamond_lateral_refinement_level();
177
178 // Tile dimensions: tune for backend occupancy / scratch usage
179 lat_tile_ = 4;
180 r_tile_ = 8;
181
182 lat_tiles_ = ( hex_lat_ + lat_tile_ - 1 ) / lat_tile_;
183 r_tiles_ = ( hex_rad_ + r_tile_ - 1 ) / r_tile_;
184
185 team_size_ = lat_tile_ * lat_tile_ * r_tile_;
186 blocks_ = local_subdomains_ * lat_tiles_ * lat_tiles_ * r_tiles_;
187
188 r_min_ = domain_info.radii()[0];
189 r_max_ = domain_info.radii()[domain_info.radii().size() - 1];
190
191 update_kernel_path_flag_host_only();
192
193 util::logroot << "[EpsilonDivDiv] tile size (x,y,r)=(" << lat_tile_ << "," << lat_tile_ << "," << r_tile_ << ")"
194 << std::endl;
195 util::logroot << "[EpsilonDivDiv] number of tiles (x,y,r)=(" << lat_tiles_ << "," << lat_tiles_ << ","
196 << r_tiles_ << "), team_size=" << team_size_ << ", blocks=" << blocks_ << std::endl;
197 util::logroot << "[EpsilonDivDiv] kernel path = " << ( use_slow_path_ ? "slow" : "fast" ) << std::endl;
198 }
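 // Worked example (illustrative numbers, not from the source): for a subdomain
 // with hex_lat_ = hex_rad_ = 16 and the defaults lat_tile_ = 4, r_tile_ = 8:
 // lat_tiles_ = (16+3)/4 = 4, r_tiles_ = (16+7)/8 = 2,
 // team_size_ = 4*4*8 = 128, blocks_ = local_subdomains_ * 4 * 4 * 2.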
199
200 void set_operator_apply_and_communication_modes(
201 const linalg::OperatorApplyMode operator_apply_mode,
202 const linalg::OperatorCommunicationMode operator_communication_mode )
203 {
204 operator_apply_mode_ = operator_apply_mode;
205 operator_communication_mode_ = operator_communication_mode;
206 }
207
208 void set_diagonal( bool v ) { diagonal_ = v; }
209
210 /// Optional runtime BC update; also refreshes fast/slow dispatch decision.
211 void set_boundary_conditions( BoundaryConditions bcs )
212 {
213 bcs_[0] = bcs[0];
214 bcs_[1] = bcs[1];
215 update_kernel_path_flag_host_only();
216 }
217
218 const grid::Grid4DDataScalar< ScalarType >& k_grid_data() { return k_; }
219 const grid::shell::DistributedDomain& get_domain() const { return domain_; }
220 grid::Grid2DDataScalar< ScalarT > get_radii() const { return radii_; }
221 grid::Grid3DDataVec< ScalarT, 3 > get_grid() { return grid_; }
222
223 /// Convenience wrapper for shell boundary mask checks.
224 KOKKOS_INLINE_FUNCTION
225 bool has_flag(
226 const int local_subdomain_id,
227 const int x_cell,
228 const int y_cell,
229 const int r_cell,
230 grid::shell::ShellBoundaryFlag flag ) const
231 {
232 return util::has_flag( mask_( local_subdomain_id, x_cell, y_cell, r_cell ), flag );
233 }
234
235 /**
236 * @brief Configure local matrix storage mode (Off / Selective / Full).
237 *
238 * This may allocate storage immediately and changes the kernel dispatch:
239 * any non-Off mode forces the slow path.
240 */
241 void set_stored_matrix_mode(
242 linalg::OperatorStoredMatrixMode operator_stored_matrix_mode,
243 int level_range,
244 grid::Grid4DDataScalar< ScalarType > GCAElements )
245 {
246 operator_stored_matrix_mode_ = operator_stored_matrix_mode;
247
248 if ( operator_stored_matrix_mode_ != linalg::OperatorStoredMatrixMode::Off )
249 {
250 local_matrix_storage_ = LocalMatrixStorage(
251 domain_, operator_stored_matrix_mode_, level_range, GCAElements );
252 }
253
254 update_kernel_path_flag_host_only();
255 }
256
257 linalg::OperatorStoredMatrixMode get_stored_matrix_mode() { return operator_stored_matrix_mode_; }
258
259 /// Store a local element matrix (used in GCA/coarsening workflows).
260 KOKKOS_INLINE_FUNCTION
261 void set_local_matrix(
262 const int local_subdomain_id,
263 const int x_cell,
264 const int y_cell,
265 const int r_cell,
266 const int wedge,
267 const dense::Mat< ScalarT, LocalMatrixDim, LocalMatrixDim >& mat ) const
268 {
269 KOKKOS_ASSERT( operator_stored_matrix_mode_ != linalg::OperatorStoredMatrixMode::Off );
270 local_matrix_storage_.set_matrix( local_subdomain_id, x_cell, y_cell, r_cell, wedge, mat );
271 }
272
273 /**
274 * @brief Get a local matrix, either from storage or assembled on-the-fly.
275 */
276 KOKKOS_INLINE_FUNCTION
277 dense::Mat< ScalarT, LocalMatrixDim, LocalMatrixDim > get_local_matrix(
278 const int local_subdomain_id,
279 const int x_cell,
280 const int y_cell,
281 const int r_cell,
282 const int wedge ) const
283 {
284 if ( operator_stored_matrix_mode_ != linalg::OperatorStoredMatrixMode::Off )
285 {
286 if ( !local_matrix_storage_.has_matrix( local_subdomain_id, x_cell, y_cell, r_cell, wedge ) )
287 {
288 Kokkos::abort( "No matrix found at that spatial index." );
289 }
290 return local_matrix_storage_.get_matrix( local_subdomain_id, x_cell, y_cell, r_cell, wedge );
291 }
292 return assemble_local_matrix( local_subdomain_id, x_cell, y_cell, r_cell, wedge );
293 }
294
295 /**
296 * @brief Apply operator to src and accumulate/replace into dst.
297 *
298 * Key design point:
299 * - The slow/fast path decision is made here (host side), before kernel launch.
300 * - This avoids branching on boundary/storage mode in the hot kernel body.
301 */
302 void apply_impl( const SrcVectorType& src, DstVectorType& dst )
303 {
304 util::Timer timer_apply( "epsilon_divdiv_apply" );
305
306 if ( operator_apply_mode_ == linalg::OperatorApplyMode::Replace )
307 {
308 assign( dst, 0 );
309 }
310
311 // Cache input/output grid views into members so the device functor sees them.
312 dst_ = dst.grid_data();
313 src_ = src.grid_data();
314
315 util::Timer timer_kernel( "epsilon_divdiv_kernel" );
316 Kokkos::TeamPolicy<> policy( blocks_, team_size_ );
317
318 // Fast path uses team scratch; slow path does not need it.
319 if ( !use_slow_path_ )
320 {
321 policy.set_scratch_size( 0, Kokkos::PerTeam( team_shmem_size( team_size_ ) ) );
322 }
323
324 // Host-side dispatch to specialized kernel
325 if ( use_slow_path_ )
326 {
327 Kokkos::parallel_for(
328 "matvec_slow",
329 policy,
330 KOKKOS_CLASS_LAMBDA( const Team& team ) {
331 this->operator_slow_kernel( team );
332 } );
333 }
334 else
335 {
336 Kokkos::parallel_for(
337 "matvec_fast",
338 policy,
339 KOKKOS_CLASS_LAMBDA( const Team& team ) {
340 this->operator_fast_kernel( team );
341 } );
342 }
343
344 Kokkos::fence();
345 timer_kernel.stop();
346
347 if ( operator_communication_mode_ == linalg::OperatorCommunicationMode::CommunicateAdditively )
348 {
349 util::Timer timer_comm( "epsilon_divdiv_comm" );
350 terra::communication::shell::send_recv_with_plan( comm_plan_, dst_, recv_buffers_ );
351 }
352 }
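 // Illustrative usage sketch (the vector objects `u` and `Au` are hypothetical,
 // not part of this header):
 //
 // EpsilonDivDivKerngenV07SplitPaths< double > op(
 // domain, grid, radii, mask, k, bcs, /*diagonal=*/false );
 // op.apply_impl( u, Au ); // dispatches to the fast or slow kernel on the host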
353
354 /**
355 * @brief Convert one gradient column (for vector component dim) into the symmetric-gradient entries.
356 *
357 * The operator is assembled/applied using sym(grad u). For basis vector component `dim`,
358 * only one gradient column is populated, so this helper computes:
359 * - diagonal symmetric entries E00,E11,E22
360 * - off-diagonal symmetric entries sym01,sym02,sym12
361 * - divergence contribution gdd (= corresponding diagonal gradient entry)
362 */
363 KOKKOS_INLINE_FUNCTION
364 void column_grad_to_sym(
365 const int dim,
366 const double g0,
367 const double g1,
368 const double g2,
369 double& E00,
370 double& E11,
371 double& E22,
372 double& sym01,
373 double& sym02,
374 double& sym12,
375 double& gdd ) const
376 {
377 E00 = E11 = E22 = sym01 = sym02 = sym12 = gdd = 0.0;
378
379 switch ( dim )
380 {
381 case 0:
382 E00 = g0;
383 gdd = g0;
384 sym01 = 0.5 * g1;
385 sym02 = 0.5 * g2;
386 break;
387 case 1:
388 E11 = g1;
389 gdd = g1;
390 sym01 = 0.5 * g0;
391 sym12 = 0.5 * g2;
392 break;
393 default:
394 E22 = g2;
395 gdd = g2;
396 sym02 = 0.5 * g0;
397 sym12 = 0.5 * g1;
398 break;
399 }
400 }
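 // In matrix form (illustrative): for vector component dim = d with gradient g,
 // sym(grad) = 0.5 * ( e_d * g^T + g * e_d^T ), so the diagonal entry E_dd equals
 // g_d (which is also the divergence contribution gdd) and the off-diagonal
 // entries are 0.5 * g_k for k != d -- exactly the cases written above.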
401
402 /**
403 * @brief Team scratch requirement for the fast path.
404 *
405 * Per team we store a slab covering:
406 * - (lat_tile_+1)x(lat_tile_+1) surface geometry nodes
407 * - (r_tile_+1) radial levels
408 * - src dofs and coefficient values for the slab
409 */
410 KOKKOS_INLINE_FUNCTION
411 size_t team_shmem_size( const int /*ts*/ ) const
412 {
413 const int nlev = r_tile_ + 1;
414 const int n = lat_tile_ + 1;
415 const int nxy = n * n;
416
417 const size_t nscalars =
418 size_t( nxy ) * 3 + // coords_sh
419 size_t( nxy ) * 3 * nlev + // src_sh
420 size_t( nxy ) * nlev + // k_sh
421 size_t( nlev ) + // r_sh
422 1;
423
424 return sizeof( ScalarType ) * nscalars;
425 }
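 // Worked example (defaults lat_tile_ = 4, r_tile_ = 8, double precision):
 // nlev = 9, n = 5, nxy = 25, so
 // nscalars = 25*3 + 25*3*9 + 25*9 + 9 + 1 = 75 + 675 + 225 + 10 = 985
 // => 985 * 8 B = 7880 B of scratch per team.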
426
427 private:
428 /**
429 * @brief Decode TeamPolicy league/team rank into subdomain/tile/cell indices.
430 *
431 * Mapping:
432 * league_rank -> (local_subdomain_id, lat_x_tile, lat_y_tile, r_tile)
433 * team_rank -> (tx, ty, tr) within the tile
434 */
435 KOKKOS_INLINE_FUNCTION
436 void decode_team_indices(
437 const Team& team,
438 int& local_subdomain_id,
439 int& x0,
440 int& y0,
441 int& r0,
442 int& tx,
443 int& ty,
444 int& tr,
445 int& x_cell,
446 int& y_cell,
447 int& r_cell ) const
448 {
449 int tmp = team.league_rank();
450
451 const int r_tile_id = tmp % r_tiles_;
452 tmp /= r_tiles_;
453
454 const int lat_y_id = tmp % lat_tiles_;
455 tmp /= lat_tiles_;
456
457 const int lat_x_id = tmp % lat_tiles_;
458 tmp /= lat_tiles_;
459
460 local_subdomain_id = tmp;
461
462 x0 = lat_x_id * lat_tile_;
463 y0 = lat_y_id * lat_tile_;
464 r0 = r_tile_id * r_tile_;
465
466 const int tid = team.team_rank();
467 tx = tid % lat_tile_;
468 ty = ( tid / lat_tile_ ) % lat_tile_;
469 tr = tid / ( lat_tile_ * lat_tile_ );
470
471 x_cell = x0 + tx;
472 y_cell = y0 + ty;
473 r_cell = r0 + tr;
474 }
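 // Worked example (illustrative): with lat_tiles_ = 4, r_tiles_ = 2,
 // lat_tile_ = 4, r_tile_ = 8 and league_rank = 13:
 // r_tile_id = 13 % 2 = 1, tmp = 6; lat_y_id = 6 % 4 = 2, tmp = 1;
 // lat_x_id = 1 % 4 = 1, tmp = 0 => local_subdomain_id = 0 and the
 // slab origin is (x0, y0, r0) = (4, 8, 8).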
475
476 /**
477 * @brief Device wrapper for slow path launch.
478 *
479 * Only does index decode + boundary flag fetch + dispatch into operator_slow_path().
480 */
481 KOKKOS_INLINE_FUNCTION
482 void operator_slow_kernel( const Team& team ) const
483 {
484 int local_subdomain_id, x0, y0, r0, tx, ty, tr, x_cell, y_cell, r_cell;
485 decode_team_indices( team, local_subdomain_id, x0, y0, r0, tx, ty, tr, x_cell, y_cell, r_cell );
486
487 if ( tr >= r_tile_ )
488 return;
489
490 const bool at_cmb = has_flag( local_subdomain_id, x_cell, y_cell, r_cell, CMB );
491 const bool at_surface = has_flag( local_subdomain_id, x_cell, y_cell, r_cell + 1, SURFACE );
492
493 operator_slow_path(
494 team, local_subdomain_id, x0, y0, r0, tx, ty, tr, x_cell, y_cell, r_cell, at_cmb, at_surface );
495 }
496
497 /**
498 * @brief Device wrapper for fast path launch.
499 *
500 * Only does index decode + boundary flag fetch + dispatch into operator_fast_path().
501 */
502 KOKKOS_INLINE_FUNCTION
503 void operator_fast_kernel( const Team& team ) const
504 {
505 int local_subdomain_id, x0, y0, r0, tx, ty, tr, x_cell, y_cell, r_cell;
506 decode_team_indices( team, local_subdomain_id, x0, y0, r0, tx, ty, tr, x_cell, y_cell, r_cell );
507
508 if ( tr >= r_tile_ )
509 return;
510
511 const bool at_cmb = has_flag( local_subdomain_id, x_cell, y_cell, r_cell, CMB );
512 const bool at_surface = has_flag( local_subdomain_id, x_cell, y_cell, r_cell + 1, SURFACE );
513
514 operator_fast_path(
515 team, local_subdomain_id, x0, y0, r0, tx, ty, tr, x_cell, y_cell, r_cell, at_cmb, at_surface );
516 }
517
518 /**
519 * @brief Slow path: local matrix-based application.
520 *
521 * Used for:
522 * - FREESLIP (normal/tangential transform required)
523 * - stored matrix mode (Full/Selective)
524 *
525 * Steps per element:
526 * 1) fetch/assemble wedge-local matrices A[2]
527 * 2) gather local src dofs
528 * 3) apply boundary treatment:
529 * - DIRICHLET: zero corresponding rows/cols via boundary mask
530 * - FREESLIP: rotate boundary dofs into n/t coordinates, apply mask, rotate back
531 * - NEUMANN: no modification
532 * 4) local matvec and atomic scatter to global dst
533 */
534 KOKKOS_INLINE_FUNCTION
535 void operator_slow_path(
536 const Team& team,
537 const int local_subdomain_id,
538 const int x0,
539 const int y0,
540 const int r0,
541 const int tx,
542 const int ty,
543 const int tr,
544 const int x_cell,
545 const int y_cell,
546 const int r_cell,
547 const bool at_cmb,
548 const bool at_surface ) const
549 {
550 // Unused in slow path body (kept in signature for symmetry with fast path)
551 (void)team;
552 (void)x0;
553 (void)y0;
554 (void)r0;
555 (void)tx;
556 (void)ty;
557 (void)tr;
558
559 if ( x_cell >= hex_lat_ || y_cell >= hex_lat_ || r_cell >= hex_rad_ )
560 return;
561
562 // ---- local matrix acquisition (stored or on-the-fly) ----
563 dense::Mat< ScalarT, LocalMatrixDim, LocalMatrixDim > A[num_wedges_per_hex_cell];
564
565 if ( operator_stored_matrix_mode_ == linalg::OperatorStoredMatrixMode::Full )
566 {
567 A[0] = local_matrix_storage_.get_matrix( local_subdomain_id, x_cell, y_cell, r_cell, 0 );
568 A[1] = local_matrix_storage_.get_matrix( local_subdomain_id, x_cell, y_cell, r_cell, 1 );
569 }
570 else if ( operator_stored_matrix_mode_ == linalg::OperatorStoredMatrixMode::Selective )
571 {
572 if ( local_matrix_storage_.has_matrix( local_subdomain_id, x_cell, y_cell, r_cell, 0 ) &&
573 local_matrix_storage_.has_matrix( local_subdomain_id, x_cell, y_cell, r_cell, 1 ) )
574 {
575 A[0] = local_matrix_storage_.get_matrix( local_subdomain_id, x_cell, y_cell, r_cell, 0 );
576 A[1] = local_matrix_storage_.get_matrix( local_subdomain_id, x_cell, y_cell, r_cell, 1 );
577 }
578 else
579 {
580 A[0] = assemble_local_matrix( local_subdomain_id, x_cell, y_cell, r_cell, 0 );
581 A[1] = assemble_local_matrix( local_subdomain_id, x_cell, y_cell, r_cell, 1 );
582 }
583 }
584 else
585 {
586 A[0] = assemble_local_matrix( local_subdomain_id, x_cell, y_cell, r_cell, 0 );
587 A[1] = assemble_local_matrix( local_subdomain_id, x_cell, y_cell, r_cell, 1 );
588 }
589
590 // ---- gather local source dofs (dimension-wise layout) ----
591 dense::Vec< ScalarT, 18 > src[num_wedges_per_hex_cell];
592 for ( int dimj = 0; dimj < 3; dimj++ )
593 {
594 dense::Vec< ScalarT, 6 > src_d[num_wedges_per_hex_cell];
595 extract_local_wedge_vector_coefficients( src_d, local_subdomain_id, x_cell, y_cell, r_cell, dimj, src_ );
596
597 for ( int wedge = 0; wedge < num_wedges_per_hex_cell; wedge++ )
598 {
599 for ( int i = 0; i < num_nodes_per_wedge; i++ )
600 {
601 src[wedge]( dimj * num_nodes_per_wedge + i ) = src_d[wedge]( i );
602 }
603 }
604 }
605
606 // Boundary masking is applied multiplicatively to local matrices.
607 // Starts as all ones; DIRICHLET/FREESLIP zero out constrained couplings.
608 dense::Mat< ScalarT, LocalMatrixDim, LocalMatrixDim > boundary_mask;
609 boundary_mask.fill( 1.0 );
610
611 bool freeslip_reorder = false;
612 dense::Mat< ScalarT, LocalMatrixDim, LocalMatrixDim > R[num_wedges_per_hex_cell];
613
614 if ( at_cmb || at_surface )
615 {
616 ShellBoundaryFlag sbf = at_cmb ? CMB : SURFACE;
617 const BoundaryConditionFlag bcf = get_boundary_condition_flag( bcs_, sbf );
618
619 if ( bcf == DIRICHLET )
620 {
621 // Zero couplings involving constrained boundary nodes.
622 for ( int dimi = 0; dimi < 3; ++dimi )
623 {
624 for ( int dimj = 0; dimj < 3; ++dimj )
625 {
626 for ( int i = 0; i < num_nodes_per_wedge; i++ )
627 {
628 for ( int j = 0; j < num_nodes_per_wedge; j++ )
629 {
630 if ( ( at_cmb && ( ( dimi == dimj && i != j && ( i < 3 || j < 3 ) ) ||
631 ( dimi != dimj && ( i < 3 || j < 3 ) ) ) ) ||
632 ( at_surface && ( ( dimi == dimj && i != j && ( i >= 3 || j >= 3 ) ) ||
633 ( dimi != dimj && ( i >= 3 || j >= 3 ) ) ) ) )
634 {
635 boundary_mask(
636 i + dimi * num_nodes_per_wedge,
637 j + dimj * num_nodes_per_wedge ) = 0.0;
638 }
639 }
640 }
641 }
642 }
643 }
644 else if ( bcf == FREESLIP )
645 {
646 // FREESLIP treatment:
647 // rotate boundary dofs (per node) to [normal, tangential1, tangential2],
648 // constrain the normal component, then rotate back.
649 freeslip_reorder = true;
650 dense::Mat< ScalarT, LocalMatrixDim, LocalMatrixDim > A_tmp[num_wedges_per_hex_cell] = { 0 };
651
652 // Reorder A into node-wise layout because FREESLIP rotation is node-local 3x3 blocks.
653 for ( int wedge = 0; wedge < 2; ++wedge )
654 {
655 for ( int dimi = 0; dimi < 3; ++dimi )
656 {
657 for ( int node_idxi = 0; node_idxi < num_nodes_per_wedge; node_idxi++ )
658 {
659 for ( int dimj = 0; dimj < 3; ++dimj )
660 {
661 for ( int node_idxj = 0; node_idxj < num_nodes_per_wedge; node_idxj++ )
662 {
663 A_tmp[wedge]( node_idxi * 3 + dimi, node_idxj * 3 + dimj ) = A[wedge](
664 node_idxi + dimi * num_nodes_per_wedge,
665 node_idxj + dimj * num_nodes_per_wedge );
666 }
667 }
668 }
669 }
671 }
672
673 // Local node offsets (in the hex cell) for the 3 nodes on the boundary face of each wedge
674 constexpr int layer_hex_offset_x[2][3] = { { 0, 1, 0 }, { 1, 0, 1 } };
675 constexpr int layer_hex_offset_y[2][3] = { { 0, 0, 1 }, { 1, 1, 0 } };
676
677 for ( int wedge = 0; wedge < 2; ++wedge )
678 {
679 // Start as identity
680 for ( int i = 0; i < LocalMatrixDim; ++i ) { R[wedge]( i, i ) = 1.0; }
681
682 // Build 3x3 rotation blocks for the boundary nodes
683 for ( int boundary_node_idx = 0; boundary_node_idx < 3; boundary_node_idx++ )
684 {
685 dense::Vec< double, 3 > normal = grid::shell::coords(
686 local_subdomain_id,
687 x_cell + layer_hex_offset_x[wedge][boundary_node_idx],
688 y_cell + layer_hex_offset_y[wedge][boundary_node_idx],
689 r_cell + ( at_cmb ? 0 : 1 ),
690 grid_,
691 radii_ );
692
693 auto R_i = trafo_mat_cartesian_to_normal_tangential( normal );
694
695 const int offset_in_R = at_cmb ? 0 : 9; // first 3 nodes (CMB) or top 3 nodes (SURFACE)
696 for ( int dimi = 0; dimi < 3; ++dimi )
697 {
698 for ( int dimj = 0; dimj < 3; ++dimj )
699 {
700 R[wedge](
701 offset_in_R + boundary_node_idx * 3 + dimi,
702 offset_in_R + boundary_node_idx * 3 + dimj ) = R_i( dimi, dimj );
703 }
704 }
705 }
706
707 // Rotate matrix and vector into n/t frame
708 A[wedge] = R[wedge] * A_tmp[wedge] * R[wedge].transposed();
709
710 auto src_tmp = R[wedge] * src[wedge];
711 for ( int i = 0; i < 18; ++i ) { src[wedge]( i ) = src_tmp( i ); }
712
713 // Constrain only the normal component at boundary nodes (node-wise layout => index = node*3 + 0)
714 const int node_start = at_surface ? 3 : 0;
715 const int node_end = at_surface ? 6 : 3;
716
717 for ( int node_idx = node_start; node_idx < node_end; node_idx++ )
718 {
719 const int idx = node_idx * 3; // normal component
720 for ( int k = 0; k < 18; ++k )
721 {
722 if ( k != idx )
723 {
724 boundary_mask( idx, k ) = 0.0;
725 boundary_mask( k, idx ) = 0.0;
726 }
727 }
728 }
729 }
730 }
731 else if ( bcf == NEUMANN )
732 {
733 // Natural BC => no extra masking
734 }
735 }
736
737 // Apply boundary masking to local matrices
738 for ( int wedge = 0; wedge < num_wedges_per_hex_cell; wedge++ )
739 {
740 A[wedge].hadamard_product( boundary_mask );
741 }
742
743 if ( diagonal_ )
744 {
745 A[0] = A[0].diagonal();
746 A[1] = A[1].diagonal();
747 }
748
749 // Local matvec
750 dense::Vec< ScalarT, LocalMatrixDim > dst[num_wedges_per_hex_cell];
751 dst[0] = A[0] * src[0];
752 dst[1] = A[1] * src[1];
753
754 // Rotate back from n/t frame and reorder back to dimension-wise layout
755 if ( freeslip_reorder )
756 {
757 dense::Vec< ScalarT, LocalMatrixDim > dst_tmp[num_wedges_per_hex_cell];
758 dst_tmp[0] = R[0].transposed() * dst[0];
759 dst_tmp[1] = R[1].transposed() * dst[1];
760
761 for ( int i = 0; i < 18; ++i )
762 {
763 dst[0]( i ) = dst_tmp[0]( i );
764 dst[1]( i ) = dst_tmp[1]( i );
765 }
766
769 }
770
771 // Scatter back to global vector (atomic because neighboring cells share nodes)
772 for ( int dimi = 0; dimi < 3; dimi++ )
773 {
774 dense::Vec< ScalarT, 6 > dst_d[num_wedges_per_hex_cell];
775 dst_d[0] = dst[0].template slice< 6 >( dimi * num_nodes_per_wedge );
776 dst_d[1] = dst[1].template slice< 6 >( dimi * num_nodes_per_wedge );
777
778 atomically_add_local_wedge_vector_coefficients(
779 dst_, local_subdomain_id, x_cell, y_cell, r_cell, dimi, dst_d );
780 }
781 }
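 // FREESLIP algebra used above, in compact form (M is the boundary mask, R the
 // node-wise rotation into (n, t1, t2) coordinates; ".*" is the Hadamard product):
 // dst = R^T * ( M .* ( R * A * R^T ) ) * ( R * src )
 // i.e. rotate, zero every coupling of the normal component, apply, rotate back.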
782
783 /**
784 * @brief Fast path: fused shared-memory matrix-free local matvec.
785 *
786 * Assumptions:
787 * - no stored matrices
788 * - no FREESLIP boundaries
789 *
790 * Team-level workflow:
791 * 1) cooperatively load a (lat_tile_+1)x(lat_tile_+1)x(r_tile_+1) slab into scratch
792 * 2) each thread computes one hex cell in the tile
793 * 3) each cell is split into 2 wedges and integrated in fused form
794 * 4) atomic scatter to dst_
795 *
796 * Boundary handling in this path supports:
797 * - DIRICHLET (via local node-range shifts / diagonal treatment)
798 * - NEUMANN (natural, no modification)
799 */
800 KOKKOS_INLINE_FUNCTION
801 void operator_fast_path(
802 const Team& team,
803 const int local_subdomain_id,
804 const int x0,
805 const int y0,
806 const int r0,
807 const int tx,
808 const int ty,
809 const int tr,
810 const int x_cell,
811 const int y_cell,
812 const int r_cell,
813 const bool at_cmb,
814 const bool at_surface ) const
815 {
816 // ---- scratch slab dimensions ----
817 const int nlev = r_tile_ + 1;
818 const int nxy = ( lat_tile_ + 1 ) * ( lat_tile_ + 1 );
819
820 // Team scratch memory layout:
821 // [coords_sh | src_sh | k_sh | r_sh]
822 double* shmem = reinterpret_cast< double* >(
823 team.team_shmem().get_shmem( team_shmem_size( team.team_size() ) ) );
824
825 using ScratchCoords =
826 Kokkos::View< double**, Kokkos::LayoutRight, typename Team::scratch_memory_space, Kokkos::MemoryUnmanaged >;
827 using ScratchSrc =
828 Kokkos::View< double***, Kokkos::LayoutRight, typename Team::scratch_memory_space, Kokkos::MemoryUnmanaged >;
829 using ScratchK =
830 Kokkos::View< double**, Kokkos::LayoutRight, typename Team::scratch_memory_space, Kokkos::MemoryUnmanaged >;
831
832 ScratchCoords coords_sh( shmem, nxy, 3 );
833 shmem += nxy * 3;
834
835 ScratchSrc src_sh( shmem, nxy, 3, nlev );
836 shmem += nxy * 3 * nlev;
837
838 ScratchK k_sh( shmem, nxy, nlev );
839 shmem += nxy * nlev;
840
841 auto r_sh =
842 Kokkos::View< double*, Kokkos::LayoutRight, typename Team::scratch_memory_space, Kokkos::MemoryUnmanaged >(
843 shmem, nlev );
844
845 auto node_id = [&]( int nx, int ny ) -> int { return nx + ( lat_tile_ + 1 ) * ny; };
846
847 // ---- cooperative tile loads ----
848 // surface geometry coords
849 Kokkos::parallel_for( Kokkos::TeamThreadRange( team, nxy ), [&]( int n ) {
850 const int dxn = n % ( lat_tile_ + 1 );
851 const int dyn = n / ( lat_tile_ + 1 );
852 const int xi = x0 + dxn;
853 const int yi = y0 + dyn;
854
855 if ( xi <= hex_lat_ && yi <= hex_lat_ )
856 {
857 coords_sh( n, 0 ) = grid_( local_subdomain_id, xi, yi, 0 );
858 coords_sh( n, 1 ) = grid_( local_subdomain_id, xi, yi, 1 );
859 coords_sh( n, 2 ) = grid_( local_subdomain_id, xi, yi, 2 );
860 }
861 else
862 {
863 coords_sh( n, 0 ) = coords_sh( n, 1 ) = coords_sh( n, 2 ) = 0.0;
864 }
865 } );
866
867 // radial coordinates
868 Kokkos::parallel_for( Kokkos::TeamThreadRange( team, nlev ), [&]( int lvl ) {
869 const int rr = r0 + lvl;
870 r_sh( lvl ) = ( rr <= hex_rad_ ) ? radii_( local_subdomain_id, rr ) : 0.0;
871 } );
872
873 // coefficient and source values
874 const int total_pairs = nxy * nlev;
875 Kokkos::parallel_for( Kokkos::TeamThreadRange( team, total_pairs ), [&]( int t ) {
876 const int lvl = t / nxy;
877 const int node = t - lvl * nxy;
878
879 const int dxn = node % ( lat_tile_ + 1 );
880 const int dyn = node / ( lat_tile_ + 1 );
881
882 const int xi = x0 + dxn;
883 const int yi = y0 + dyn;
884 const int rr = r0 + lvl;
885
886 if ( xi <= hex_lat_ && yi <= hex_lat_ && rr <= hex_rad_ )
887 {
888 k_sh( node, lvl ) = k_( local_subdomain_id, xi, yi, rr );
889 src_sh( node, 0, lvl ) = src_( local_subdomain_id, xi, yi, rr, 0 );
890 src_sh( node, 1, lvl ) = src_( local_subdomain_id, xi, yi, rr, 1 );
891 src_sh( node, 2, lvl ) = src_( local_subdomain_id, xi, yi, rr, 2 );
892 }
893 else
894 {
895 k_sh( node, lvl ) = 0.0;
896 src_sh( node, 0, lvl ) = src_sh( node, 1, lvl ) = src_sh( node, 2, lvl ) = 0.0;
897 }
898 } );
899
900 team.team_barrier();
901
902 // Each logical thread computes one hex cell in the tile
903 if ( x_cell >= hex_lat_ || y_cell >= hex_lat_ || r_cell >= hex_rad_ )
904 return;
905
906 const int lvl0 = tr;
907 const double r_0 = r_sh( lvl0 );
908 const double r_1 = r_sh( lvl0 + 1 );
909
910 // In the fast path we only treat DIRICHLET specially.
911 // FREESLIP is excluded by host-side dispatch.
912 const bool at_boundary = at_cmb || at_surface;
913 bool treat_boundary_dirichlet = false;
914 if ( at_boundary )
915 {
916 const ShellBoundaryFlag sbf = at_cmb ? CMB : SURFACE;
917 treat_boundary_dirichlet = ( get_boundary_condition_flag( bcs_, sbf ) == DIRICHLET );
918 }
919
920 // For full (non-diagonal) application, skip constrained face nodes.
921 // For diagonal mode, the diagonal term can still be accumulated with boundary handling below.
922 const int cmb_shift = ( ( at_boundary && treat_boundary_dirichlet && ( !diagonal_ ) && at_cmb ) ? 3 : 0 );
923 const int surface_shift =
924 ( ( at_boundary && treat_boundary_dirichlet && ( !diagonal_ ) && at_surface ) ? 3 : 0 );
925
926 // Wedge connectivity / local constants
927 static constexpr int WEDGE_NODE_OFF[2][6][3] = {
928 { { 0, 0, 0 }, { 1, 0, 0 }, { 0, 1, 0 }, { 0, 0, 1 }, { 1, 0, 1 }, { 0, 1, 1 } },
929 { { 1, 1, 0 }, { 0, 1, 0 }, { 1, 0, 0 }, { 1, 1, 1 }, { 0, 1, 1 }, { 1, 0, 1 } } };
930
931 // Map wedge-local nodes into the 8 unique hex nodes used for scatter accumulation
932 static constexpr int WEDGE_TO_UNIQUE[2][6] = {
933 { 0, 1, 2, 3, 4, 5 }, // wedge 0
934 { 6, 2, 1, 7, 5, 4 } // wedge 1
935 };
936
937 constexpr double ONE_THIRD = 1.0 / 3.0;
938 constexpr double ONE_SIXTH = 1.0 / 6.0;
939 constexpr double NEG_TWO_THIRDS = -0.66666666666666663;
940
941 // Reference gradients of wedge basis functions at the single Felippa 1x1 point
942 static constexpr double dN_ref[6][3] = {
943 { -0.5, -0.5, -ONE_SIXTH },
944 { 0.5, 0.0, -ONE_SIXTH },
945 { 0.0, 0.5, -ONE_SIXTH },
946 { -0.5, -0.5, ONE_SIXTH },
947 { 0.5, 0.0, ONE_SIXTH },
948 { 0.0, 0.5, ONE_SIXTH } };
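 // (These values follow from the wedge shape functions N_i = T_i(xi, eta) * (1 -/+ zeta) / 2
 // evaluated at the centroid quadrature point (1/3, 1/3, 0): the lateral derivatives are
 // -/+0.5 or 0, and the zeta derivative is -/+ T_i(1/3, 1/3) / 2 = -/+ 1/6.)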
949
950 // Four lateral nodes of the hex footprint inside this tile
951 const int n00 = node_id( tx, ty );
952 const int n01 = node_id( tx, ty + 1 );
953 const int n10 = node_id( tx + 1, ty );
954 const int n11 = node_id( tx + 1, ty + 1 );
955
956 // Surface coordinates for the 2 wedges (3 points each)
957 double ws[2][3][3];
958
959 // wedge 0: (q00,q10,q01)
960 ws[0][0][0] = coords_sh( n00, 0 );
961 ws[0][0][1] = coords_sh( n00, 1 );
962 ws[0][0][2] = coords_sh( n00, 2 );
963 ws[0][1][0] = coords_sh( n10, 0 );
964 ws[0][1][1] = coords_sh( n10, 1 );
965 ws[0][1][2] = coords_sh( n10, 2 );
966 ws[0][2][0] = coords_sh( n01, 0 );
967 ws[0][2][1] = coords_sh( n01, 1 );
968 ws[0][2][2] = coords_sh( n01, 2 );
969
970 // wedge 1: (q11,q01,q10)
971 ws[1][0][0] = coords_sh( n11, 0 );
972 ws[1][0][1] = coords_sh( n11, 1 );
973 ws[1][0][2] = coords_sh( n11, 2 );
974 ws[1][1][0] = coords_sh( n01, 0 );
975 ws[1][1][1] = coords_sh( n01, 1 );
976 ws[1][1][2] = coords_sh( n01, 2 );
977 ws[1][2][0] = coords_sh( n10, 0 );
978 ws[1][2][1] = coords_sh( n10, 1 );
979 ws[1][2][2] = coords_sh( n10, 2 );
980
981 // Per-thread accumulation into 8 hex nodes x 3 vector components.
982 // We accumulate locally in registers and atomically scatter once at the end.
983 double dst8[3][8] = { 0.0 };
984
985 for ( int w = 0; w < 2; ++w )
986 {
987 // Coefficient k evaluated at quadrature point via average of 6 wedge nodes for 1x1 rule
988 double k_sum = 0.0;
989#pragma unroll
990 for ( int node = 0; node < 6; ++node )
991 {
992 const int ddx = WEDGE_NODE_OFF[w][node][0];
993 const int ddy = WEDGE_NODE_OFF[w][node][1];
994 const int ddr = WEDGE_NODE_OFF[w][node][2];
995
996 const int nid = node_id( tx + ddx, ty + ddy );
997 const int lvl = lvl0 + ddr;
998
999 k_sum += k_sh( nid, lvl );
1000 }
1001 const double k_eval = ONE_SIXTH * k_sum;
1002
1003 // Compute inverse Jacobian and |det J| for this wedge element
1004 double wJ = 0.0;
1005 double i00, i01, i02;
1006 double i10, i11, i12;
1007 double i20, i21, i22;
1008
1009 {
1010 const double half_dr = 0.5 * ( r_1 - r_0 );
1011 const double r_mid = 0.5 * ( r_0 + r_1 );
1012
1013 const double J_0_0 = r_mid * ( -ws[w][0][0] + ws[w][1][0] );
1014 const double J_0_1 = r_mid * ( -ws[w][0][0] + ws[w][2][0] );
1015 const double J_0_2 = half_dr * ( ONE_THIRD * ( ws[w][0][0] + ws[w][1][0] + ws[w][2][0] ) );
1016
1017 const double J_1_0 = r_mid * ( -ws[w][0][1] + ws[w][1][1] );
1018 const double J_1_1 = r_mid * ( -ws[w][0][1] + ws[w][2][1] );
1019 const double J_1_2 = half_dr * ( ONE_THIRD * ( ws[w][0][1] + ws[w][1][1] + ws[w][2][1] ) );
1020
1021 const double J_2_0 = r_mid * ( -ws[w][0][2] + ws[w][1][2] );
1022 const double J_2_1 = r_mid * ( -ws[w][0][2] + ws[w][2][2] );
1023 const double J_2_2 = half_dr * ( ONE_THIRD * ( ws[w][0][2] + ws[w][1][2] + ws[w][2][2] ) );
1024
1025 const double J_det = J_0_0 * J_1_1 * J_2_2 - J_0_0 * J_1_2 * J_2_1 - J_0_1 * J_1_0 * J_2_2 +
1026 J_0_1 * J_1_2 * J_2_0 + J_0_2 * J_1_0 * J_2_1 - J_0_2 * J_1_1 * J_2_0;
1027
1028 const double invJ = 1.0 / J_det;
1029
1030 i00 = invJ * ( J_1_1 * J_2_2 - J_1_2 * J_2_1 );
1031 i01 = invJ * ( -J_1_0 * J_2_2 + J_1_2 * J_2_0 );
1032 i02 = invJ * ( J_1_0 * J_2_1 - J_1_1 * J_2_0 );
1033 i10 = invJ * ( -J_0_1 * J_2_2 + J_0_2 * J_2_1 );
1034 i11 = invJ * ( J_0_0 * J_2_2 - J_0_2 * J_2_0 );
1035 i12 = invJ * ( -J_0_0 * J_2_1 + J_0_1 * J_2_0 );
1036 i20 = invJ * ( J_0_1 * J_1_2 - J_0_2 * J_1_1 );
1037 i21 = invJ * ( -J_0_0 * J_1_2 + J_0_2 * J_1_0 );
1038 i22 = invJ * ( J_0_0 * J_1_1 - J_0_1 * J_1_0 );
1039
1040 wJ = Kokkos::abs( J_det );
1041 }
1042
1043 const double kwJ = k_eval * wJ;
1044
1045 // Fused operator action:
1046 // - first pass builds grad(u) / div(u)-dependent accumulators
1047 // - second pass tests each basis function and accumulates to dst8
1048 double gu00 = 0.0;
1049 double gu10 = 0.0, gu11 = 0.0;
1050 double gu20 = 0.0, gu21 = 0.0, gu22 = 0.0;
1051 double div_u = 0.0;
1052
1053 if ( !diagonal_ )
1054 {
1055 // Build sym(grad u) and div(u) at quadrature point
1056 for ( int dimj = 0; dimj < 3; ++dimj )
1057 {
1058#pragma unroll
1059 for ( int node_idx = cmb_shift; node_idx < 6 - surface_shift; ++node_idx )
1060 {
1061 const double gx = dN_ref[node_idx][0];
1062 const double gy = dN_ref[node_idx][1];
1063 const double gz = dN_ref[node_idx][2];
1064
1065 const double g0 = i00 * gx + i01 * gy + i02 * gz;
1066 const double g1 = i10 * gx + i11 * gy + i12 * gz;
1067 const double g2 = i20 * gx + i21 * gy + i22 * gz;
1068
1069 double E00, E11, E22, sym01, sym02, sym12, gdd;
1070 column_grad_to_sym( dimj, g0, g1, g2, E00, E11, E22, sym01, sym02, sym12, gdd );
1071
1072 const int ddx = WEDGE_NODE_OFF[w][node_idx][0];
1073 const int ddy = WEDGE_NODE_OFF[w][node_idx][1];
1074 const int ddr = WEDGE_NODE_OFF[w][node_idx][2];
1075
1076 const int nid = node_id( tx + ddx, ty + ddy );
1077 const int lvl = lvl0 + ddr;
1078
1079 const double s = src_sh( nid, dimj, lvl );
1080
1081 gu00 += E00 * s;
1082 gu10 += sym01 * s;
1083 gu11 += E11 * s;
1084 gu20 += sym02 * s;
1085 gu21 += sym12 * s;
1086 gu22 += E22 * s;
1087 div_u += gdd * s;
1088 }
1089 }
1090
1091 // Test against each basis function and accumulate local contributions
1092 for ( int dimi = 0; dimi < 3; ++dimi )
1093 {
1094#pragma unroll
1095 for ( int node_idx = cmb_shift; node_idx < 6 - surface_shift; ++node_idx )
1096 {
1097 const double gx = dN_ref[node_idx][0];
1098 const double gy = dN_ref[node_idx][1];
1099 const double gz = dN_ref[node_idx][2];
1100
1101 const double g0 = i00 * gx + i01 * gy + i02 * gz;
1102 const double g1 = i10 * gx + i11 * gy + i12 * gz;
1103 const double g2 = i20 * gx + i21 * gy + i22 * gz;
1104
1105 double E00, E11, E22, sym01, sym02, sym12, gdd;
1106 column_grad_to_sym( dimi, g0, g1, g2, E00, E11, E22, sym01, sym02, sym12, gdd );
1107
1108 const int u = WEDGE_TO_UNIQUE[w][node_idx];
1109
1110 dst8[dimi][u] += kwJ * ( NEG_TWO_THIRDS * div_u * gdd +
1111 4.0 * sym01 * gu10 +
1112 4.0 * sym02 * gu20 +
1113 4.0 * sym12 * gu21 +
1114 2.0 * E00 * gu00 +
1115 2.0 * E11 * gu11 +
1116 2.0 * E22 * gu22 );
1117 }
1118 }
1119 }
1120
1121 // Diagonal-only mode (or diagonal correction on DIRICHLET boundaries)
1122 if ( diagonal_ || ( treat_boundary_dirichlet && at_boundary ) )
1123 {
1124 for ( int dim_diagBC = 0; dim_diagBC < 3; ++dim_diagBC )
1125 {
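 // Note the swapped shifts relative to the full pass above: in the non-diagonal
 // DIRICHLET case this loop visits exactly the constrained face nodes that were
 // skipped there, so their diagonal entries still receive a contribution.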
1126#pragma unroll
1127 for ( int node_idx = surface_shift; node_idx < 6 - cmb_shift; ++node_idx )
1128 {
1129 const double gx = dN_ref[node_idx][0];
1130 const double gy = dN_ref[node_idx][1];
1131 const double gz = dN_ref[node_idx][2];
1132
1133 const double g0 = i00 * gx + i01 * gy + i02 * gz;
1134 const double g1 = i10 * gx + i11 * gy + i12 * gz;
1135 const double g2 = i20 * gx + i21 * gy + i22 * gz;
1136
1137 double E00, E11, E22, sym01, sym02, sym12, gdd;
1138 column_grad_to_sym( dim_diagBC, g0, g1, g2, E00, E11, E22, sym01, sym02, sym12, gdd );
1139
1140 const int ddx = WEDGE_NODE_OFF[w][node_idx][0];
1141 const int ddy = WEDGE_NODE_OFF[w][node_idx][1];
1142 const int ddr = WEDGE_NODE_OFF[w][node_idx][2];
1143
1144 const int nid = node_id( tx + ddx, ty + ddy );
1145 const int lvl = lvl0 + ddr;
1146
1147 const double s = src_sh( nid, dim_diagBC, lvl );
1148 const int u = WEDGE_TO_UNIQUE[w][node_idx];
1149
1150 dst8[dim_diagBC][u] +=
1151 kwJ * ( 4.0 * s * ( sym01 * sym01 + sym02 * sym02 + sym12 * sym12 ) +
1152 2.0 * s * ( E00 * E00 + E11 * E11 + E22 * E22 ) +
1153 NEG_TWO_THIRDS * ( gdd * gdd ) * s );
1154 }
1155 }
1156 }
1157 } // wedge loop
1158
1159 // Final atomic scatter to global dst vector (8 hex nodes x 3 components)
1160 for ( int dim_add = 0; dim_add < 3; ++dim_add )
1161 {
1162 Kokkos::atomic_add( &dst_( local_subdomain_id, x_cell, y_cell, r_cell, dim_add ), dst8[dim_add][0] );
1163 Kokkos::atomic_add( &dst_( local_subdomain_id, x_cell + 1, y_cell, r_cell, dim_add ), dst8[dim_add][1] );
1164 Kokkos::atomic_add( &dst_( local_subdomain_id, x_cell, y_cell + 1, r_cell, dim_add ), dst8[dim_add][2] );
1165 Kokkos::atomic_add( &dst_( local_subdomain_id, x_cell, y_cell, r_cell + 1, dim_add ), dst8[dim_add][3] );
1166 Kokkos::atomic_add( &dst_( local_subdomain_id, x_cell + 1, y_cell, r_cell + 1, dim_add ), dst8[dim_add][4] );
1167 Kokkos::atomic_add( &dst_( local_subdomain_id, x_cell, y_cell + 1, r_cell + 1, dim_add ), dst8[dim_add][5] );
1168 Kokkos::atomic_add( &dst_( local_subdomain_id, x_cell + 1, y_cell + 1, r_cell, dim_add ), dst8[dim_add][6] );
1169 Kokkos::atomic_add( &dst_( local_subdomain_id, x_cell + 1, y_cell + 1, r_cell + 1, dim_add ), dst8[dim_add][7] );
1170 }
1171 }
1172
1173 public:
1174 /**
1175 * @brief Fallback functor entry point.
1176 *
1177 * Not used by apply_impl anymore (which launches specialized kernels directly),
1178 * but kept for compatibility/debug.
1179 */
1180 KOKKOS_INLINE_FUNCTION
1181 void operator()( const Team& team ) const
1182 {
1183 if ( use_slow_path_ ) { operator_slow_kernel( team ); }
1184 else { operator_fast_kernel( team ); }
1185 }
1186
1187 // -------------------------------------------------------------------------
1188 // The remaining methods are unchanged numerically; only comments added.
1189 // -------------------------------------------------------------------------
1190
1191 /**
1192 * @brief Build trial/test symmetric-gradient vectors at one quadrature point and wedge.
1193 *
1194 * For a given pair of vector components (dimi, dimj), this computes:
1196 * - sym_grad_i / sym_grad_j for all local wedge nodes
1197 * - scalar quadrature factor = weight * k(x_q) * |det J|
1198 *
1199 * These vectors are reused both for:
1200 * - fused local matvecs, and
1201 * - explicit local matrix assembly (outer products)
1202 */
1203 KOKKOS_INLINE_FUNCTION void assemble_trial_test_vecs(
1204 const int wedge,
1205 const dense::Vec< ScalarType, VecDim >& quad_point,
1206 const ScalarType quad_weight,
1207 const ScalarT r_1,
1208 const ScalarT r_2,
1209 dense::Vec< ScalarT, 3 > ( *wedge_phy_surf )[3],
1210 const dense::Vec< ScalarT, 6 >* k_local_hex,
1211 const int dimi,
1212 const int dimj,
1213 dense::Mat< ScalarType, VecDim, VecDim >* sym_grad_i,
1214 dense::Mat< ScalarType, VecDim, VecDim >* sym_grad_j,
1215 ScalarType& jdet_keval_quadweight ) const
1216 {
1217 dense::Mat< ScalarType, VecDim, VecDim > J = jac( wedge_phy_surf[wedge], r_1, r_2, quad_point );
1218 const auto det = J.det();
1219 const auto abs_det = Kokkos::abs( det );
1220 const dense::Mat< ScalarType, VecDim, VecDim > J_inv_transposed = J.inv_transposed( det );
1221
1222 ScalarType k_eval = 0.0;
1223 for ( int k = 0; k < num_nodes_per_wedge; k++ )
1224 {
1225 k_eval += shape( k, quad_point ) * k_local_hex[wedge]( k );
1226 }
1227
1228 for ( int k = 0; k < num_nodes_per_wedge; k++ )
1229 {
1230 sym_grad_i[k] = symmetric_grad( J_inv_transposed, quad_point, k, dimi );
1231 sym_grad_j[k] = symmetric_grad( J_inv_transposed, quad_point, k, dimj );
1232 }
1233
1234 jdet_keval_quadweight = quad_weight * k_eval * abs_det;
1235 }
1236
1237 /**
1238 * @brief Assemble one wedge-local 18x18 matrix on-the-fly.
1239 *
1240 * This is used by the slow path when local matrices are not stored, and by GCA workflows.
1241 */
1242 KOKKOS_INLINE_FUNCTION
1243 dense::Mat< ScalarT, LocalMatrixDim, LocalMatrixDim > assemble_local_matrix(
1244 const int local_subdomain_id,
1245 const int x_cell,
1246 const int y_cell,
1247 const int r_cell,
1248 const int wedge ) const
1249 {
1250 dense::Vec< ScalarT, 3 > wedge_phy_surf[num_wedges_per_hex_cell][num_nodes_per_wedge_surface];
1251 wedge_surface_physical_coords( wedge_phy_surf, grid_, local_subdomain_id, x_cell, y_cell );
1252
1253 const ScalarT r_1 = radii_( local_subdomain_id, r_cell );
1254 const ScalarT r_2 = radii_( local_subdomain_id, r_cell + 1 );
1255
1256 dense::Vec< ScalarT, 6 > k_local_hex[num_wedges_per_hex_cell];
1257 extract_local_wedge_scalar_coefficients( k_local_hex, local_subdomain_id, x_cell, y_cell, r_cell, k_ );
1258
1259 dense::Mat< ScalarT, LocalMatrixDim, LocalMatrixDim > A = { 0 };
1260
1261 for ( int dimi = 0; dimi < 3; ++dimi )
1262 {
1263 for ( int dimj = 0; dimj < 3; ++dimj )
1264 {
1265 for ( int q = 0; q < num_quad_points; q++ )
1266 {
1267 dense::Mat< ScalarType, VecDim, VecDim > sym_grad_i[num_nodes_per_wedge];
1268 dense::Mat< ScalarType, VecDim, VecDim > sym_grad_j[num_nodes_per_wedge];
1269 ScalarType jdet_keval_quadweight = 0;
1270
1271 assemble_trial_test_vecs(
1272 wedge,
1273 quad_points[q],
1274 quad_weights[q],
1275 r_1,
1276 r_2,
1277 wedge_phy_surf,
1278 k_local_hex,
1279 dimi,
1280 dimj,
1281 sym_grad_i,
1282 sym_grad_j,
1283 jdet_keval_quadweight );
1284
1285 for ( int i = 0; i < num_nodes_per_wedge; i++ )
1286 {
1287 for ( int j = 0; j < num_nodes_per_wedge; j++ )
1288 {
1289 A( i + dimi * num_nodes_per_wedge, j + dimj * num_nodes_per_wedge ) +=
1290 jdet_keval_quadweight *
1291 ( 2 * sym_grad_j[j].double_contract( sym_grad_i[i] ) -
1292 2.0 / 3.0 * sym_grad_j[j]( dimj, dimj ) * sym_grad_i[i]( dimi, dimi ) );
1293 }
1294 }
1295 }
1296 }
1297 }
1298
1299 return A;
1300 }
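 // Weak form realized above (sketch): with eps(.) the symmetric gradient,
 // A(i,j) = sum_q w_q * k(x_q) * |det J(x_q)| *
 // ( 2 * eps(phi_j) : eps(phi_i) - (2/3) * div(phi_j) * div(phi_i) ),
 // where div(phi) reduces to the corresponding diagonal entry of eps(phi).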
1301};
1302
1305
1306} // namespace terra::fe::wedge::operators::shell::epsdivdiv_history
Definition communication_plan.hpp:33
Send and receive buffers for all process-local subdomain boundaries.
Definition communication.hpp:56
Matrix-free / matrix-based epsilon-div-div operator on wedge elements in a spherical shell.
Definition epsilon_divdiv_kerngen_v07_split_paths.hpp:51
dense::Mat< ScalarT, LocalMatrixDim, LocalMatrixDim > assemble_local_matrix(const int local_subdomain_id, const int x_cell, const int y_cell, const int r_cell, const int wedge) const
Assemble one wedge-local 18x18 matrix on-the-fly.
Definition epsilon_divdiv_kerngen_v07_split_paths.hpp:1243
void set_boundary_conditions(BoundaryConditions bcs)
Optional runtime BC update; also refreshes fast/slow dispatch decision.
Definition epsilon_divdiv_kerngen_v07_split_paths.hpp:211
void set_operator_apply_and_communication_modes(const linalg::OperatorApplyMode operator_apply_mode, const linalg::OperatorCommunicationMode operator_communication_mode)
Definition epsilon_divdiv_kerngen_v07_split_paths.hpp:200
bool has_flag(const int local_subdomain_id, const int x_cell, const int y_cell, const int r_cell, grid::shell::ShellBoundaryFlag flag) const
Convenience wrapper for shell boundary mask checks.
Definition epsilon_divdiv_kerngen_v07_split_paths.hpp:225
void assemble_trial_test_vecs(const int wedge, const dense::Vec< ScalarType, VecDim > &quad_point, const ScalarType quad_weight, const ScalarT r_1, const ScalarT r_2, dense::Vec< ScalarT, 3 >(*wedge_phy_surf)[3], const dense::Vec< ScalarT, 6 > *k_local_hex, const int dimi, const int dimj, dense::Mat< ScalarType, VecDim, VecDim > *sym_grad_i, dense::Mat< ScalarType, VecDim, VecDim > *sym_grad_j, ScalarType &jdet_keval_quadweight) const
Definition epsilon_divdiv_kerngen_v07_split_paths.hpp:1203
dense::Mat< ScalarT, LocalMatrixDim, LocalMatrixDim > get_local_matrix(const int local_subdomain_id, const int x_cell, const int y_cell, const int r_cell, const int wedge) const
Get a local matrix, either from storage or assembled on-the-fly.
Definition epsilon_divdiv_kerngen_v07_split_paths.hpp:277
void column_grad_to_sym(const int dim, const double g0, const double g1, const double g2, double &E00, double &E11, double &E22, double &sym01, double &sym02, double &sym12, double &gdd) const
Convert one gradient column (for vector component dim) into the symmetric-gradient entries.
Definition epsilon_divdiv_kerngen_v07_split_paths.hpp:364
void set_stored_matrix_mode(linalg::OperatorStoredMatrixMode operator_stored_matrix_mode, int level_range, grid::Grid4DDataScalar< ScalarType > GCAElements)
Configure local matrix storage mode (Off / Selective / Full).
Definition epsilon_divdiv_kerngen_v07_split_paths.hpp:241
size_t team_shmem_size(const int) const
Team scratch requirement for the fast path.
Definition epsilon_divdiv_kerngen_v07_split_paths.hpp:411
grid::Grid3DDataVec< ScalarT, 3 > get_grid()
Definition epsilon_divdiv_kerngen_v07_split_paths.hpp:221
static constexpr int LocalMatrixDim
Definition epsilon_divdiv_kerngen_v07_split_paths.hpp:56
void operator()(const Team &team) const
Fallback functor entry point.
Definition epsilon_divdiv_kerngen_v07_split_paths.hpp:1181
const grid::shell::DistributedDomain & get_domain() const
Definition epsilon_divdiv_kerngen_v07_split_paths.hpp:219
void apply_impl(const SrcVectorType &src, DstVectorType &dst)
Apply operator to src and accumulate/replace into dst.
Definition epsilon_divdiv_kerngen_v07_split_paths.hpp:302
grid::Grid2DDataScalar< ScalarT > get_radii() const
Definition epsilon_divdiv_kerngen_v07_split_paths.hpp:220
const grid::Grid4DDataScalar< ScalarType > & k_grid_data()
Definition epsilon_divdiv_kerngen_v07_split_paths.hpp:218
Kokkos::TeamPolicy<>::member_type Team
Definition epsilon_divdiv_kerngen_v07_split_paths.hpp:60
terra::grid::Grid4DDataMatrices< ScalarType, LocalMatrixDim, LocalMatrixDim, 2 > Grid4DDataLocalMatrices
Definition epsilon_divdiv_kerngen_v07_split_paths.hpp:58
EpsilonDivDivKerngenV07SplitPaths(const grid::shell::DistributedDomain &domain, const grid::Grid3DDataVec< ScalarT, 3 > &grid, const grid::Grid2DDataScalar< ScalarT > &radii, const grid::Grid4DDataScalar< grid::shell::ShellBoundaryFlag > &mask, const grid::Grid4DDataScalar< ScalarT > &k, BoundaryConditions bcs, bool diagonal, linalg::OperatorApplyMode operator_apply_mode=linalg::OperatorApplyMode::Replace, linalg::OperatorCommunicationMode operator_communication_mode=linalg::OperatorCommunicationMode::CommunicateAdditively, linalg::OperatorStoredMatrixMode operator_stored_matrix_mode=linalg::OperatorStoredMatrixMode::Off)
Definition epsilon_divdiv_kerngen_v07_split_paths.hpp:142
linalg::OperatorStoredMatrixMode get_stored_matrix_mode()
Definition epsilon_divdiv_kerngen_v07_split_paths.hpp:257
ScalarT ScalarType
Definition epsilon_divdiv_kerngen_v07_split_paths.hpp:55
void set_diagonal(bool v)
Definition epsilon_divdiv_kerngen_v07_split_paths.hpp:208
void set_local_matrix(const int local_subdomain_id, const int x_cell, const int y_cell, const int r_cell, const int wedge, const dense::Mat< ScalarT, LocalMatrixDim, LocalMatrixDim > &mat) const
Store a local element matrix (used in GCA/coarsening workflows).
Definition epsilon_divdiv_kerngen_v07_split_paths.hpp:261
Parallel data structure organizing the thick spherical shell metadata for distributed (MPI parallel) ...
Definition spherical_shell.hpp:2518
const std::map< SubdomainInfo, std::tuple< LocalSubdomainIdx, SubdomainNeighborhood > > & subdomains() const
Definition spherical_shell.hpp:2650
const DomainInfo & domain_info() const
Returns a const reference.
Definition spherical_shell.hpp:2647
Information about the thick spherical shell mesh.
Definition spherical_shell.hpp:780
const std::vector< double > & radii() const
Definition spherical_shell.hpp:845
int diamond_lateral_refinement_level() const
Definition spherical_shell.hpp:843
int subdomain_num_nodes_radially() const
Equivalent to calling subdomain_num_nodes_radially( subdomain_refinement_level() )
Definition spherical_shell.hpp:861
int subdomain_num_nodes_per_side_laterally() const
Equivalent to calling subdomain_num_nodes_per_side_laterally( subdomain_refinement_level() )
Definition spherical_shell.hpp:852
Static assertion: VectorQ1Scalar satisfies VectorLike concept.
Definition vector_q1.hpp:168
const grid::Grid4DDataVec< ScalarType, VecDim > & grid_data() const
Get const reference to grid data.
Definition vector_q1.hpp:288
bool has_matrix(const int local_subdomain_id, const int x_cell, const int y_cell, const int r_cell, const int wedge) const
Checks for presence of a local matrix for a certain element.
Definition local_matrix_storage.hpp:223
dense::Mat< ScalarT, LocalMatrixDim, LocalMatrixDim > get_matrix(const int local_subdomain_id, const int x_cell, const int y_cell, const int r_cell, const int wedge) const
Retrieves the local matrix: if there are stored local matrices, the desired local matrix is loaded and r...
Definition local_matrix_storage.hpp:175
void set_matrix(const int local_subdomain_id, const int x_cell, const int y_cell, const int r_cell, const int wedge, dense::Mat< ScalarT, LocalMatrixDim, LocalMatrixDim > mat) const
Set the local matrix stored in the operator.
Definition local_matrix_storage.hpp:118
Timer supporting RAII scope or manual stop.
Definition timer.hpp:342
void stop()
Stop the timer and record elapsed time.
Definition timer.hpp:364
Concept for types that can be used as Galerkin coarse-grid operators in a multigrid hierarchy....
Definition operator.hpp:81
k_eval
Definition EpsilonDivDiv_kernel_gen.py:175
g2
Definition EpsilonDivDiv_kernel_gen.py:267
div_u
Definition EpsilonDivDiv_kernel_gen.py:359
cmb_shift
Definition EpsilonDivDiv_kernel_gen.py:98
g1
Definition EpsilonDivDiv_kernel_gen.py:266
J_det
Definition EpsilonDivDiv_kernel_gen.py:216
surface_shift
Definition EpsilonDivDiv_kernel_gen.py:98
constexpr int idx(const int loop_idx, const int size, const grid::BoundaryPosition position, const grid::BoundaryDirection direction)
Definition buffer_copy_kernels.hpp:73
void send_recv_with_plan(const ShellBoundaryCommPlan< GridDataType > &plan, const GridDataType &data, SubdomainNeighborhoodSendRecvBuffer< typename GridDataType::value_type, grid::grid_data_vec_dim< GridDataType >() > &recv_buffers, CommunicationReduction reduction=CommunicationReduction::SUM)
Definition communication_plan.hpp:652
Definition epsilon_divdiv_kerngen_v01_initial.hpp:16
constexpr void quad_felippa_1x1_quad_points(dense::Vec< T, 3 >(&quad_points)[quad_felippa_1x1_num_quad_points])
Definition wedge/quadrature/quadrature.hpp:36
constexpr void quad_felippa_1x1_quad_weights(T(&quad_weights)[quad_felippa_1x1_num_quad_points])
Definition wedge/quadrature/quadrature.hpp:43
constexpr int quad_felippa_1x1_num_quad_points
Definition wedge/quadrature/quadrature.hpp:32
constexpr int num_nodes_per_wedge_surface
Definition kernel_helpers.hpp:6
constexpr dense::Mat< T, 3, 3 > symmetric_grad(const dense::Mat< T, 3, 3 > &J_inv_transposed, const dense::Vec< T, 3 > &quad_point, const int dof, const int dim)
Returns the symmetric gradient of the shape function of a dof at a quadrature point.
Definition integrands.hpp:685
void wedge_surface_physical_coords(dense::Vec< T, 3 >(&wedge_surf_phy_coords)[num_wedges_per_hex_cell][num_nodes_per_wedge_surface], const grid::Grid3DDataVec< T, 3 > &lateral_grid, const int local_subdomain_id, const int x_cell, const int y_cell)
Extracts the (unit sphere) surface vertex coords of the two wedges of a hex cell.
Definition kernel_helpers.hpp:26
constexpr void reorder_local_dofs(const DoFOrdering doo_from, const DoFOrdering doo_to, dense::Vec< ScalarT, 18 > &dofs)
Definition kernel_helpers.hpp:619
void atomically_add_local_wedge_vector_coefficients(const grid::Grid4DDataVec< T, VecDim > &global_coefficients, const int local_subdomain_id, const int x_cell, const int y_cell, const int r_cell, const int d, const dense::Vec< T, 6 > local_coefficients[2])
Performs an atomic add of the two local wedge coefficient vectors of a hex cell into the global coeff...
Definition kernel_helpers.hpp:465
constexpr int num_wedges_per_hex_cell
Definition kernel_helpers.hpp:5
void extract_local_wedge_scalar_coefficients(dense::Vec< T, 6 >(&local_coefficients)[2], const int local_subdomain_id, const int x_cell, const int y_cell, const int r_cell, const grid::Grid4DDataScalar< T > &global_coefficients)
Extracts the local scalar coefficients for the two wedges of a hex cell from the global coefficient v...
Definition kernel_helpers.hpp:306
void extract_local_wedge_vector_coefficients(dense::Vec< T, 6 >(&local_coefficients)[2], const int local_subdomain_id, const int x_cell, const int y_cell, const int r_cell, const int d, const grid::Grid4DDataVec< T, VecDim > &global_coefficients)
Extracts the local vector coefficients for the two wedges of a hex cell from the global coefficient v...
Definition kernel_helpers.hpp:356
constexpr int num_nodes_per_wedge
Definition kernel_helpers.hpp:7
constexpr T shape(const int node_idx, const T xi, const T eta, const T zeta)
(Tensor-product) Shape function.
Definition integrands.hpp:146
constexpr dense::Mat< T, 3, 3 > jac(const dense::Vec< T, 3 > &p1_phy, const dense::Vec< T, 3 > &p2_phy, const dense::Vec< T, 3 > &p3_phy, const T r_1, const T r_2, const T xi, const T eta, const T zeta)
Definition integrands.hpp:657
dense::Vec< typename CoordsShellType::value_type, 3 > coords(const int subdomain, const int x, const int y, const int r, const CoordsShellType &coords_shell, const CoordsRadiiType &coords_radii)
Definition spherical_shell.hpp:2871
BoundaryConditionMapping[2] BoundaryConditions
Definition shell/bit_masks.hpp:37
ShellBoundaryFlag
FlagLike that indicates boundary types for the thick spherical shell.
Definition shell/bit_masks.hpp:12
BoundaryConditionFlag get_boundary_condition_flag(const BoundaryConditions bcs, ShellBoundaryFlag sbf)
Retrieve the boundary condition flag that is associated with a location in the shell e....
Definition shell/bit_masks.hpp:42
BoundaryConditionFlag
FlagLike that indicates the type of boundary condition
Definition shell/bit_masks.hpp:25
Kokkos::View< dense::Mat< ScalarType, Rows, Cols > ****[NumMatrices], Layout > Grid4DDataMatrices
Definition grid_types.hpp:173
Kokkos::View< ScalarType ***[VecDim], Layout > Grid3DDataVec
Definition grid_types.hpp:42
Kokkos::View< ScalarType ****, Layout > Grid4DDataScalar
Definition grid_types.hpp:27
Kokkos::View< ScalarType **, Layout > Grid2DDataScalar
Definition grid_types.hpp:21
dense::Mat< ScalarType, 3, 3 > trafo_mat_cartesian_to_normal_tangential(const dense::Vec< ScalarType, 3 > &n_input)
Constructs a robust orthonormal transformation matrix from Cartesian to (normal–tangential–tangential...
Definition local_basis_trafo_normal_tangential.hpp:36
OperatorApplyMode
Modes for applying an operator to a vector.
Definition operator.hpp:30
@ Replace
Overwrite the destination vector.
OperatorStoredMatrixMode
Modes for applying stored matrices.
Definition operator.hpp:47
@ Selective
Use stored matrices on selected, marked elements only, assemble on all others.
@ Full
Use stored matrices on all elements.
@ Off
Do not use stored matrices.
OperatorCommunicationMode
Modes for communication during operator application.
Definition operator.hpp:40
@ CommunicateAdditively
Communicate and add results.
constexpr bool has_flag(E mask_value, E flag) noexcept
Checks if a bitmask value contains a specific flag.
Definition bit_masking.hpp:43
detail::PrefixCout logroot([]() { return detail::log_prefix();})
std::ostream subclass that just logs on root and adds a timestamp for each line.
Definition mat.hpp:10
constexpr Mat diagonal() const
Definition mat.hpp:377
constexpr Mat< T, Cols, Rows > transposed() const
Definition mat.hpp:187
Mat & hadamard_product(const Mat &mat)
Definition mat.hpp:213
T double_contract(const Mat &mat)
Definition mat.hpp:226
SoA (Structure-of-Arrays) 4D vector grid data.
Definition grid_types.hpp:51