// Component count of GridDataType, as reported by grid::grid_data_vec_dim. Used as the VecDim
// template argument of the Grid0D/1D/2DDataVec buffer types and of copy_to_buffer.
36 static constexpr int VecDim = grid::grid_data_vec_dim< GridDataType >();
// Store the address of `domain` and the local-communication switch.
// NOTE(review): domain_ is a raw pointer taken from a reference and is presumably not owned
// by this object, so the domain must outlive the plan -- confirm.
41 : domain_( &domain ), enable_local_comm_( enable_local_comm )
// Create the per-rank MPI staging buffers and request slots.
44 allocate_rank_buffers_();
49 const GridDataType& data,
// One complete boundary exchange; the phases below run in this order.
53 util::Timer timer_all(
"shell_boundary_exchange_and_reduce" );
// Phase 1: same-rank neighbors -- copy boundary data straight into the receive buffers.
57 local_comm_copy_into_recv_buffers_( data, boundary_recv_buffers );
// Phase 2: serialize remote pieces into the per-rank send buffers.
59 pack_remote_sends_( data );
// NOTE(review): the lines between packing and scattering are not visible in this excerpt;
// presumably they post the receives/sends and wait for them (post_irecvs_, post_isends_,
// wait_all_) -- confirm against the full file.
// Phase 3: copy received per-rank data into the per-piece receive buffers.
65 scatter_recvs_into_boundary_buffers_( boundary_recv_buffers );
// Phase 4: combine every received piece into `data` using `reduction`.
67 unpack_and_reduce_( data, boundary_recv_buffers, reduction );
// Recreate the per-rank MPI staging buffers and request slots from the current per-rank totals.
74 allocate_rank_buffers_();
// Kind of boundary piece: 0 = vertex (packed as Grid0DDataVec), 1 = edge (grid::BoundaryEdge,
// Grid1DDataVec), 2 = face (grid::BoundaryFace, Grid2DDataVec). The default -1 is not a valid
// kind; every consumer aborts with "Unknown boundary type" on any other value.
80 int boundary_type = -1;
// Boundary of local_subdomain that this piece lives on, stored as the int value of the
// vertex/edge/face boundary enum (cast back according to boundary_type).
83 int local_subdomain_boundary;
// Index of local_subdomain, passed to copy_to_buffer together with `data` when packing.
84 int local_subdomain_id;
// Matching boundary on neighbor_subdomain, stored as an int like local_subdomain_boundary.
88 int neighbor_subdomain_boundary;
// Number of ScalarType elements that the boundary piece described by `p` occupies in a per-rank
// MPI buffer. build_plan uses it as the chunk size when laying out send and receive buffers.
// Dispatches on p.boundary_type (0 = vertex, 1 = edge, 2 = face) and aborts on any other value.
// NOTE(review): the return expressions are not visible in this excerpt; presumably
// (number of nodes in the piece) * VecDim -- confirm.
101 int piece_num_scalars_(
const SendRecvPair& p )
const
103 const auto& domain = *domain_;
// Vertex piece (no spatial extent).
105 if ( p.boundary_type == 0 )
// Edge piece: its node count is either the lateral or the radial subdomain node count,
// chosen by the kind of edge (the selecting condition is not visible here).
109 else if ( p.boundary_type == 1 )
111 const auto local_edge_boundary =
static_cast< grid::BoundaryEdge >( p.local_subdomain_boundary );
114 domain.domain_info().subdomain_num_nodes_per_side_laterally();
// Face piece: ni lateral nodes times a second extent that is either the lateral or the radial
// node count, depending on the face (the selecting condition is not visible here).
117 else if ( p.boundary_type == 2 )
119 const auto local_face_boundary =
static_cast< grid::BoundaryFace >( p.local_subdomain_boundary );
120 const int ni = domain.domain_info().subdomain_num_nodes_per_side_laterally();
122 domain.domain_info().subdomain_num_nodes_per_side_laterally() :
123 domain.domain_info().subdomain_num_nodes_radially();
126 Kokkos::abort(
"Unknown boundary type" );
// Builds the communication plan (the function head is not visible in this excerpt):
//  1. send_recv_pairs_: one SendRecvPair per (local vertex/edge/face boundary, neighbor) of every
//     subdomain in the domain;
//  2. local_pairs_: the same-rank subset, when enable_local_comm_ is set;
//  3. send/recv chunk lists and per-rank totals: the offset and size of every remaining piece
//     inside the buffer exchanged with its neighbor rank.
132 util::Timer timer(
"ShellBoundaryCommPlan::build_plan" );
134 const auto& domain = *domain_;
// 1024 is only an initial capacity hint; the vector still grows as needed.
136 send_recv_pairs_.clear();
137 send_recv_pairs_.reserve( 1024 );
// NOTE(review): local_subdomain_id and neighborhood are unpacked from idx_and_neighborhood in
// lines not visible here (presumably index first, matching the std::get< 0 > lookup in
// local_comm_copy_into_recv_buffers_).
140 for (
const auto& [local_subdomain_info, idx_and_neighborhood] : domain.subdomains() )
// Vertex neighbors: one local vertex boundary may touch several neighbor subdomains.
144 for (
const auto& [local_vertex_boundary, neighbors] : neighborhood.neighborhood_vertex() )
146 for (
const auto& neighbor : neighbors )
148 const auto& [neighbor_subdomain_info, neighbor_local_boundary, neighbor_rank] = neighbor;
149 send_recv_pairs_.push_back( SendRecvPair{
152 .local_subdomain = local_subdomain_info,
153 .local_subdomain_boundary =
static_cast< int >( local_vertex_boundary ),
154 .local_subdomain_id = local_subdomain_id,
155 .neighbor_rank = neighbor_rank,
156 .neighbor_subdomain = neighbor_subdomain_info,
157 .neighbor_subdomain_boundary =
static_cast< int >( neighbor_local_boundary ) } );
// Edge neighbors: the third tuple element (bound as boundary_direction in unpack_and_reduce_)
// does not affect the buffer layout and is ignored here.
161 for (
const auto& [local_edge_boundary, neighbors] : neighborhood.neighborhood_edge() )
163 for (
const auto& neighbor : neighbors )
165 const auto& [neighbor_subdomain_info, neighbor_local_boundary, _, neighbor_rank] = neighbor;
166 send_recv_pairs_.push_back( SendRecvPair{
169 .local_subdomain = local_subdomain_info,
170 .local_subdomain_boundary =
static_cast< int >( local_edge_boundary ),
171 .local_subdomain_id = local_subdomain_id,
172 .neighbor_rank = neighbor_rank,
173 .neighbor_subdomain = neighbor_subdomain_info,
174 .neighbor_subdomain_boundary =
static_cast< int >( neighbor_local_boundary ) } );
// Face neighbors: exactly one neighbor per local face, hence no inner loop.
178 for (
const auto& [local_face_boundary, neighbor] : neighborhood.neighborhood_face() )
180 const auto& [neighbor_subdomain_info, neighbor_local_boundary, _, neighbor_rank] = neighbor;
181 send_recv_pairs_.push_back( SendRecvPair{
184 .local_subdomain = local_subdomain_info,
185 .local_subdomain_boundary =
static_cast< int >( local_face_boundary ),
186 .local_subdomain_id = local_subdomain_id,
187 .neighbor_rank = neighbor_rank,
188 .neighbor_subdomain = neighbor_subdomain_info,
189 .neighbor_subdomain_boundary =
static_cast< int >( neighbor_local_boundary ) } );
// When local communication is enabled, same-rank pairs bypass MPI. They are collected here and
// later copied directly by local_comm_copy_into_recv_buffers_.
194 local_pairs_.clear();
195 local_pairs_.reserve( send_recv_pairs_.size() );
196 for (
const auto& p : send_recv_pairs_ )
198 if ( enable_local_comm_ && p.local_rank == p.neighbor_rank )
199 local_pairs_.push_back( p );
// Send layout. Pieces are ordered by (type, local subdomain, local boundary, neighbor subdomain,
// neighbor boundary). The receive comparator below uses the same keys with local and neighbor
// swapped. Provided the remote rank holds the mirrored pair (local/neighbor exchanged), both
// ranks therefore place each shared piece at the same buffer offset. Any change to one
// comparator must be mirrored in the other.
204 auto send_pairs = send_recv_pairs_;
205 std::sort( send_pairs.begin(), send_pairs.end(), [](
const SendRecvPair& a,
const SendRecvPair& b ) {
206 if ( a.boundary_type != b.boundary_type ) return a.boundary_type < b.boundary_type;
207 if ( a.local_subdomain != b.local_subdomain ) return a.local_subdomain < b.local_subdomain;
208 if ( a.local_subdomain_boundary != b.local_subdomain_boundary )
209 return a.local_subdomain_boundary < b.local_subdomain_boundary;
210 if ( a.neighbor_subdomain != b.neighbor_subdomain ) return a.neighbor_subdomain < b.neighbor_subdomain;
211 return a.neighbor_subdomain_boundary < b.neighbor_subdomain_boundary;
// In sorted order, give every MPI-bound piece a contiguous [offset, offset + size) slice of the
// send buffer for its neighbor rank. std::map::operator[] starts each rank's running total at 0.
214 send_chunks_by_rank_.clear();
215 send_total_by_rank_.clear();
217 for (
const auto& p : send_pairs )
// Same-rank pairs are excluded when local comm is enabled (the skip statement itself is not
// visible in this excerpt).
219 if ( enable_local_comm_ && p.local_rank == p.neighbor_rank )
222 const int sz = piece_num_scalars_( p );
223 auto& chunks = send_chunks_by_rank_[p.neighbor_rank];
225 const int off = send_total_by_rank_[p.neighbor_rank];
226 send_total_by_rank_[p.neighbor_rank] += sz;
228 chunks.push_back( ChunkInfo{ .pair = p, .offset = off, .size = sz } );
// Receive layout: the mirror image of the send ordering (neighbor keys first, then local keys).
234 auto recv_pairs = send_recv_pairs_;
235 std::sort( recv_pairs.begin(), recv_pairs.end(), [](
const SendRecvPair& a,
const SendRecvPair& b ) {
236 if ( a.boundary_type != b.boundary_type ) return a.boundary_type < b.boundary_type;
237 if ( a.neighbor_subdomain != b.neighbor_subdomain ) return a.neighbor_subdomain < b.neighbor_subdomain;
238 if ( a.neighbor_subdomain_boundary != b.neighbor_subdomain_boundary )
239 return a.neighbor_subdomain_boundary < b.neighbor_subdomain_boundary;
240 if ( a.local_subdomain != b.local_subdomain ) return a.local_subdomain < b.local_subdomain;
241 return a.local_subdomain_boundary < b.local_subdomain_boundary;
// Same slice assignment as for sends, keyed by the rank the data comes from.
244 recv_chunks_by_rank_.clear();
245 recv_total_by_rank_.clear();
247 for (
const auto& p : recv_pairs )
// Same-rank pairs are excluded when local comm is enabled (skip statement not visible here).
249 if ( enable_local_comm_ && p.local_rank == p.neighbor_rank )
252 const int sz = piece_num_scalars_( p );
253 auto& chunks = recv_chunks_by_rank_[p.neighbor_rank];
255 const int off = recv_total_by_rank_[p.neighbor_rank];
256 recv_total_by_rank_[p.neighbor_rank] += sz;
258 chunks.push_back( ChunkInfo{ .pair = p, .offset = off, .size = sz } );
// Rebuilds the per-rank MPI staging buffers from the totals computed by build_plan. Both buffer
// maps are cleared and repopulated with one entry per rank listed in send_total_by_rank_ /
// recv_total_by_rank_.
// NOTE(review): the allocation statements inside the loops are not visible here. post_isends_
// and post_irecvs_ use each buffer's extent(0) as the message length, so presumably each buffer
// holds `total` ScalarType elements -- confirm.
263 void allocate_rank_buffers_()
265 util::Timer timer(
"ShellBoundaryCommPlan::allocate_rank_buffers" );
267 send_rank_buffers_.clear();
268 recv_rank_buffers_.clear();
270 for (
const auto& [rank, total] : send_total_by_rank_ )
275 for (
const auto& [rank, total] : recv_total_by_rank_ )
// One MPI_Request slot per rank buffer; post_isends_/post_irecvs_ fill them via
// &data_*_requests_[i].
281 data_send_requests_.resize( send_rank_buffers_.size() );
282 data_recv_requests_.resize( recv_rank_buffers_.size() );
// Posts one receive per neighbor rank. Each receive covers that rank's whole receive buffer:
// buf.extent(0) elements of mpi::mpi_datatype<ScalarType>(). The request is stored in
// data_recv_requests_[i] and completed later by MPI_Waitall in wait_all_. The method is declared
// const because the buffers and request vectors it touches are `mutable` members.
// NOTE(review): the MPI call itself (presumably MPI_Irecv), its tag and communicator, and the
// handling of the index `i` are not visible in this excerpt. The length is narrowed to int, so a
// single buffer larger than INT_MAX elements would overflow.
288 void post_irecvs_()
const
290 util::Timer timer(
"ShellBoundaryCommPlan::post_irecvs" );
293 for (
const auto& [rank, buf] : recv_rank_buffers_ )
295 const int total_sz =
static_cast< int >( buf.extent( 0 ) );
299 mpi::mpi_datatype< ScalarType >(),
303 &data_recv_requests_[i] );
// Handles same-rank neighbors without MPI. For every pair in local_pairs_, this copies the
// neighbor subdomain's boundary piece from `data`, addressed by the neighbor's local index, into
// the matching piece of boundary_recv_buffers. That is the same receive-buffer container that
// scatter_recvs_into_boundary_buffers_ fills for remote pairs. Aborts if the neighbor subdomain
// is not present in this rank's domain, or on an unknown boundary type.
// NOTE(review): the remaining arguments of the buffer lookups and copy_to_buffer calls are not
// visible in this excerpt.
309 void local_comm_copy_into_recv_buffers_(
310 const GridDataType& data,
313 util::Timer timer(
"ShellBoundaryCommPlan::local_comm" );
315 const auto& domain = *domain_;
317 for (
const auto& p : local_pairs_ )
// local_pairs_ holds only same-rank pairs, so the neighbor must be owned locally.
319 if ( !domain.subdomains().contains( p.neighbor_subdomain ) )
320 Kokkos::abort(
"Subdomain not found locally - but it should be there..." );
// Local index of the neighbor subdomain: first element of its domain.subdomains() entry.
322 const auto neighbor_subdomain_id = std::get< 0 >( domain.subdomains().at( p.neighbor_subdomain ) );
// Vertex piece.
324 if ( p.boundary_type == 0 )
329 p.neighbor_subdomain,
332 copy_to_buffer<VecDim>(
335 neighbor_subdomain_id,
// Edge piece.
338 else if ( p.boundary_type == 1 )
340 auto& recv_buf = boundary_recv_buffers.
buffer_edge(
343 p.neighbor_subdomain,
346 copy_to_buffer<VecDim>(
349 neighbor_subdomain_id,
// Face piece.
352 else if ( p.boundary_type == 2 )
354 auto& recv_buf = boundary_recv_buffers.
buffer_face(
357 p.neighbor_subdomain,
360 copy_to_buffer<VecDim>(
363 neighbor_subdomain_id,
368 Kokkos::abort(
"Unknown boundary type" );
// Serializes every MPI-bound piece into the per-rank send buffers. For each chunk, an unmanaged
// view with the piece's shape is laid over rank_buf at ch.offset. copy_to_buffer then fills it
// from the local subdomain's boundary (p.local_subdomain_id) in `data`. The piece shape comes
// from p.local_subdomain_boundary, the same field piece_num_scalars_ uses to size the chunk.
// Aborts on an unknown boundary type.
373 void pack_remote_sends_(
const GridDataType& data )
const
375 util::Timer timer(
"ShellBoundaryCommPlan::pack_remote" );
377 const auto& domain = *domain_;
379 for (
const auto& [rank, chunks] : send_chunks_by_rank_ )
381 auto& rank_buf = send_rank_buffers_.at( rank );
383 for (
const auto& ch : chunks )
385 const auto& p = ch.pair;
// Start of this piece's slice inside the rank buffer.
386 ScalarType* base_ptr = rank_buf.data() + ch.offset;
// Vertex piece: 0D view.
388 if ( p.boundary_type == 0 )
390 using BufT = grid::Grid0DDataVec< ScalarType, VecDim >;
391 auto unmanaged = detail::make_unmanaged_like< BufT >( base_ptr );
393 copy_to_buffer<VecDim>(
396 p.local_subdomain_id,
// Edge piece: 1D view of n_nodes entries (radial or lateral count, depending on the edge).
399 else if ( p.boundary_type == 1 )
401 using BufT = grid::Grid1DDataVec< ScalarType, VecDim >;
402 const auto local_edge_boundary =
static_cast< grid::BoundaryEdge >( p.local_subdomain_boundary );
404 domain.domain_info().subdomain_num_nodes_radially() :
405 domain.domain_info().subdomain_num_nodes_per_side_laterally();
407 auto unmanaged = detail::make_unmanaged_like< BufT >( base_ptr, n_nodes );
408 copy_to_buffer<VecDim>( unmanaged, data, p.local_subdomain_id, local_edge_boundary );
// Face piece: 2D view of ni x nj entries.
410 else if ( p.boundary_type == 2 )
412 using BufT = grid::Grid2DDataVec< ScalarType, VecDim >;
413 const auto local_face_boundary =
static_cast< grid::BoundaryFace >( p.local_subdomain_boundary );
414 const int ni = domain.domain_info().subdomain_num_nodes_per_side_laterally();
416 domain.domain_info().subdomain_num_nodes_per_side_laterally() :
417 domain.domain_info().subdomain_num_nodes_radially();
419 auto unmanaged = detail::make_unmanaged_like< BufT >( base_ptr, ni, nj );
420 copy_to_buffer<VecDim>( unmanaged, data, p.local_subdomain_id, local_face_boundary );
424 Kokkos::abort(
"Unknown boundary type" );
// Block until all outstanding Kokkos work (the copies above) has completed before the send
// buffers are handed to MPI.
429 Kokkos::fence(
"pack_rank_send_buffers" );
// Posts one send per neighbor rank. Each send covers that rank's whole send buffer:
// buf.extent(0) elements of mpi::mpi_datatype<ScalarType>(). The request is stored in
// data_send_requests_[i] and completed later by MPI_Waitall in wait_all_. The method is declared
// const because the buffers and request vectors it touches are `mutable` members.
// NOTE(review): the MPI call itself (presumably MPI_Isend), its tag and communicator, and the
// handling of the index `i` are not visible in this excerpt. The length is narrowed to int, so a
// single buffer larger than INT_MAX elements would overflow.
432 void post_isends_()
const
434 util::Timer timer(
"ShellBoundaryCommPlan::post_isends" );
437 for (
const auto& [rank, buf] : send_rank_buffers_ )
439 const int total_sz =
static_cast< int >( buf.extent( 0 ) );
443 mpi::mpi_datatype< ScalarType >(),
447 &data_send_requests_[i] );
// Blocks until every posted send has completed, then until every posted receive has completed.
// send_req_count_ / recv_req_count_ give the number of leading request entries to wait on
// (presumably set where the requests are posted; not visible here). A zero count skips the
// corresponding MPI_Waitall, and statuses are discarded.
453 void wait_all_()
const
455 util::Timer timer(
"ShellBoundaryCommPlan::waitall" );
457 if ( send_req_count_ > 0 )
458 MPI_Waitall( send_req_count_, data_send_requests_.data(), MPI_STATUSES_IGNORE );
459 if ( recv_req_count_ > 0 )
460 MPI_Waitall( recv_req_count_, data_recv_requests_.data(), MPI_STATUSES_IGNORE );
// Moves every received chunk out of its per-rank receive buffer into the matching per-piece
// buffer in boundary_recv_buffers. An unmanaged view with the piece's shape is laid over
// rank_buf at ch.offset and deep-copied into the buffer looked up for the pair. The shape comes
// from p.local_subdomain_boundary, the same field piece_num_scalars_ uses to size the chunk in
// the receive layout. Aborts on an unknown boundary type.
// NOTE(review): the remaining lookup arguments and the vertex-buffer lookup head are not
// visible in this excerpt.
463 void scatter_recvs_into_boundary_buffers_(
466 util::Timer timer(
"ShellBoundaryCommPlan::scatter_recvs" );
468 const auto& domain = *domain_;
470 for (
const auto& [rank, chunks] : recv_chunks_by_rank_ )
472 auto& rank_buf = recv_rank_buffers_.at( rank );
474 for (
const auto& ch : chunks )
476 const auto& p = ch.pair;
// Start of this piece's slice inside the rank buffer.
477 ScalarType* base_ptr = rank_buf.data() + ch.offset;
// Vertex piece: 0D view.
479 if ( p.boundary_type == 0 )
481 using BufT = grid::Grid0DDataVec< ScalarType, VecDim >;
482 auto unmanaged = detail::make_unmanaged_like< BufT >( base_ptr );
487 p.neighbor_subdomain,
490 Kokkos::deep_copy( recv_buf, unmanaged );
// Edge piece: 1D view of n_nodes entries.
492 else if ( p.boundary_type == 1 )
494 using BufT = grid::Grid1DDataVec< ScalarType, VecDim >;
496 const auto local_edge_boundary =
static_cast< grid::BoundaryEdge >( p.local_subdomain_boundary );
498 domain.domain_info().subdomain_num_nodes_radially() :
499 domain.domain_info().subdomain_num_nodes_per_side_laterally();
501 auto unmanaged = detail::make_unmanaged_like< BufT >( base_ptr, n_nodes );
503 auto& recv_buf = boundary_recv_buffers.
buffer_edge(
506 p.neighbor_subdomain,
509 Kokkos::deep_copy( recv_buf, unmanaged );
// Face piece: 2D view of ni x nj entries.
511 else if ( p.boundary_type == 2 )
513 using BufT = grid::Grid2DDataVec< ScalarType, VecDim >;
515 const auto local_face_boundary =
static_cast< grid::BoundaryFace >( p.local_subdomain_boundary );
516 const int ni = domain.domain_info().subdomain_num_nodes_per_side_laterally();
518 domain.domain_info().subdomain_num_nodes_per_side_laterally() :
519 domain.domain_info().subdomain_num_nodes_radially();
521 auto unmanaged = detail::make_unmanaged_like< BufT >( base_ptr, ni, nj );
523 auto& recv_buf = boundary_recv_buffers.
buffer_face(
526 p.neighbor_subdomain,
529 Kokkos::deep_copy( recv_buf, unmanaged );
533 Kokkos::abort(
"Unknown boundary type" );
// Block until all outstanding Kokkos work has completed before the receive buffers are read.
538 Kokkos::fence(
"scatter_rank_recv_buffers" );
// Folds all received boundary data into `data`. It walks the same vertex/edge/face
// neighborhoods as build_plan. For every (local boundary, neighbor) pair it looks up the receive
// buffer keyed by (local subdomain, local boundary, neighbor subdomain, neighbor boundary) and
// passes it to a reduce helper. The helper also receives `data`, the local subdomain index, the
// local boundary and `reduction`; edges and faces additionally pass the orientation element(s)
// of the neighbor tuple.
// NOTE(review): the reduce helper names, and the unpacking of local_subdomain_id and neighborhood
// from idx_and_neighborhood, are not visible in this excerpt. neighbor_rank is bound but not
// used in the visible lines.
541 void unpack_and_reduce_(
542 const GridDataType& data,
546 util::Timer timer(
"ShellBoundaryCommPlan::unpack_and_reduce" );
548 const auto& domain = *domain_;
550 for (
const auto& [local_subdomain_info, idx_and_neighborhood] : domain.subdomains() )
// Vertex neighbors (possibly several per local vertex).
554 for (
const auto& [local_vertex_boundary, neighbors] : neighborhood.neighborhood_vertex() )
556 for (
const auto& neighbor : neighbors )
558 const auto& [neighbor_subdomain_info, neighbor_local_boundary, neighbor_rank] = neighbor;
561 local_subdomain_info, local_vertex_boundary, neighbor_subdomain_info, neighbor_local_boundary );
564 recv_buffer, data, local_subdomain_id, local_vertex_boundary, reduction );
// Edge neighbors: boundary_direction is forwarded to the reduction.
568 for (
const auto& [local_edge_boundary, neighbors] : neighborhood.neighborhood_edge() )
570 for (
const auto& neighbor : neighbors )
572 const auto& [neighbor_subdomain_info, neighbor_local_boundary, boundary_direction, neighbor_rank] =
575 auto recv_buffer = boundary_recv_buffers.
buffer_edge(
576 local_subdomain_info, local_edge_boundary, neighbor_subdomain_info, neighbor_local_boundary );
579 recv_buffer, data, local_subdomain_id, local_edge_boundary, boundary_direction, reduction );
// Face neighbors: exactly one per local face; boundary_directions is forwarded to the reduction.
583 for (
const auto& [local_face_boundary, neighbor] : neighborhood.neighborhood_face() )
585 const auto& [neighbor_subdomain_info, neighbor_local_boundary, boundary_directions, neighbor_rank] =
588 auto recv_buffer = boundary_recv_buffers.
buffer_face(
589 local_subdomain_info, local_face_boundary, neighbor_subdomain_info, neighbor_local_boundary );
592 recv_buffer, data, local_subdomain_id, local_face_boundary, boundary_directions, reduction );
// Domain whose subdomains and neighborhoods drive the plan; set from the constructor's `domain`
// reference. NOTE(review): presumably not owned -- the domain must outlive the plan.
600 const grid::shell::DistributedDomain* domain_ =
nullptr;
// When true, same-rank pairs bypass MPI and are copied directly.
601 bool enable_local_comm_ =
true;
// One entry per (local vertex/edge/face boundary, neighbor) pair, built by build_plan.
604 std::vector< SendRecvPair > send_recv_pairs_;
// Same-rank subset of send_recv_pairs_ (only filled when enable_local_comm_ is true).
607 std::vector< SendRecvPair > local_pairs_;
// Per neighbor rank: ordered pieces and their [offset, offset + size) slices in that rank's
// send / receive buffer.
610 std::map< mpi::MPIRank, std::vector< ChunkInfo > > send_chunks_by_rank_;
611 std::map< mpi::MPIRank, std::vector< ChunkInfo > > recv_chunks_by_rank_;
// Per neighbor rank: total buffer length in ScalarType elements (sum of its chunk sizes).
612 std::map< mpi::MPIRank, int > send_total_by_rank_;
613 std::map< mpi::MPIRank, int > recv_total_by_rank_;
// Per-rank MPI staging buffers. They are mutable so that const methods such as
// pack_remote_sends_ and post_irecvs_ can fill and post them.
616 mutable std::map< mpi::MPIRank, rank_buffer_view > send_rank_buffers_;
617 mutable std::map< mpi::MPIRank, rank_buffer_view > recv_rank_buffers_;
// One request slot per rank buffer, sized in allocate_rank_buffers_.
620 mutable std::vector< MPI_Request > data_send_requests_;
621 mutable std::vector< MPI_Request > data_recv_requests_;
// Number of leading request entries that wait_all_ passes to MPI_Waitall.
622 mutable int send_req_count_ = 0;
623 mutable int recv_req_count_ = 0;