// Compile-time constant: the value of grid::grid_data_vec_dim< GridDataType >().
// It is passed as the VecDim template argument to the GridNDDataVec buffer aliases
// and to copy_to_buffer<VecDim>( ... ) further down in this class.
36 static constexpr int VecDim = grid::grid_data_vec_dim< GridDataType >();
// Constructor (member-initializer list and part of the body): stores the address of
// the caller's domain and the local-communication flag, then sizes the per-rank buffers.
// NOTE(review): domain_ keeps a raw pointer to the argument, so the domain presumably
// has to outlive this object — confirm against callers.
41 : domain_( &domain ), enable_local_comm_( enable_local_comm )
// Buffer sizes come from send_total_by_rank_ / recv_total_by_rank_, so the plan is
// presumably built on a preceding line that is not shown here — TODO confirm.
44 allocate_rank_buffers_();
// data: the grid data whose subdomain boundaries are exchanged (taken by const reference).
49 const GridDataType& data,
// Times the whole exchange under a single label.
53 util::Timer timer_all(
"shell_boundary_exchange_and_reduce" );
// Same-rank pairs: copy the neighbor's boundary values straight into the caller's
// boundary receive buffers.
57 local_comm_copy_into_recv_buffers_( data, boundary_recv_buffers );
// Remote pairs: pack the boundary values into the contiguous per-rank send buffers.
59 pack_remote_sends_( data );
// Apply the locally copied boundary values to `data` using `reduction`.
64 unpack_local_( data, boundary_recv_buffers, reduction );
// Wait for the remote receives, unpacking each rank's buffer as it completes, then
// wait for the sends.
// NOTE(review): post_irecvs_() / post_isends_() are presumably called on lines of
// this function that are not shown — confirm.
68 wait_and_unpack_remote_( data, reduction );
77 allocate_rank_buffers_();
// Boundary kind selector. The code in this class treats 0 as a vertex boundary
// (grid::BoundaryVertex, Grid0DDataVec), 1 as an edge (grid::BoundaryEdge,
// Grid1DDataVec) and 2 as a face (grid::BoundaryFace, Grid2DDataVec); any other
// value, including the default -1, hits the "Unknown boundary type" abort.
83 int boundary_type = -1;
// Local boundary identifier, stored as the int value of the boundary enum that
// matches boundary_type (it is cast back with static_cast where it is used).
86 int local_subdomain_boundary;
// Local index of the subdomain; passed as the subdomain argument of the
// copy_to_buffer / unpack calls.
87 int local_subdomain_id;
// Boundary identifier on the neighbor's side, stored the same way as
// local_subdomain_boundary.
91 int neighbor_subdomain_boundary;
// Returns the size of one boundary piece for the pair `p`; build_plan uses the result
// as the chunk size and offset increment inside the per-rank buffers (offsets are
// applied to a ScalarType pointer, so the unit is scalars).
// The size depends on p.boundary_type (0 / 1 / 2) and, for edges and faces, on the
// specific local boundary and the lateral / radial node counts of the domain.
// Aborts via Kokkos::abort for any other boundary type.
// NOTE(review): most of the size expressions are on lines not shown here, so the
// exact formulas are intentionally not described.
109 int piece_num_scalars_(
const SendRecvPair& p )
const
111 const auto& domain = *domain_;
// Vertex boundary.
113 if ( p.boundary_type == 0 )
// Edge boundary: the int is converted back to the edge enum.
117 else if ( p.boundary_type == 1 )
119 const auto local_edge_boundary =
static_cast< grid::BoundaryEdge >( p.local_subdomain_boundary );
122 domain.domain_info().subdomain_num_nodes_per_side_laterally();
// Face boundary: the first extent is the lateral node count; the second is either
// lateral or radial (the selecting condition is not visible).
125 else if ( p.boundary_type == 2 )
127 const auto local_face_boundary =
static_cast< grid::BoundaryFace >( p.local_subdomain_boundary );
128 const int ni = domain.domain_info().subdomain_num_nodes_per_side_laterally();
130 domain.domain_info().subdomain_num_nodes_per_side_laterally() :
131 domain.domain_info().subdomain_num_nodes_radially();
// Fallback for an unrecognized boundary_type.
134 Kokkos::abort(
"Unknown boundary type" );
// Builds the communication plan:
//   1. collects one SendRecvPair per (local boundary, neighbor) over all subdomains,
//   2. extracts the same-rank pairs into local_pairs_ (if local comm is enabled),
//   3. lays out per-rank send chunks (offset + size) in a deterministic sort order,
//   4. lays out per-rank receive chunks in the mirrored sort order.
140 util::Timer timer(
"ShellBoundaryCommPlan::build_plan" );
142 const auto& domain = *domain_;
// Start from an empty list; 1024 is an up-front capacity guess, not a limit.
144 send_recv_pairs_.clear();
145 send_recv_pairs_.reserve( 1024 );
// One iteration per subdomain in domain.subdomains().
// NOTE(review): `local_subdomain_id` and `neighborhood` used below are presumably
// unpacked from idx_and_neighborhood on a line that is not shown — confirm.
148 for (
const auto& [local_subdomain_info, idx_and_neighborhood] : domain.subdomains() )
// Vertex neighbors: each local vertex boundary maps to a collection of neighbors,
// and every neighbor yields its own pair.
152 for (
const auto& [local_vertex_boundary, neighbors] : neighborhood.neighborhood_vertex() )
154 for (
const auto& neighbor : neighbors )
156 const auto& [neighbor_subdomain_info, neighbor_local_boundary, neighbor_rank] = neighbor;
157 send_recv_pairs_.push_back( SendRecvPair{
159 .local_rank =
mpi::rank( domain.comm() ),
160 .local_subdomain = local_subdomain_info,
161 .local_subdomain_boundary =
static_cast< int >( local_vertex_boundary ),
162 .local_subdomain_id = local_subdomain_id,
163 .neighbor_rank = neighbor_rank,
164 .neighbor_subdomain = neighbor_subdomain_info,
165 .neighbor_subdomain_boundary =
static_cast< int >( neighbor_local_boundary ) } );
// Edge neighbors: same shape as the vertex case, but each neighbor entry also
// carries an edge direction, stored in direction_0.
169 for (
const auto& [local_edge_boundary, neighbors] : neighborhood.neighborhood_edge() )
171 for (
const auto& neighbor : neighbors )
173 const auto& [neighbor_subdomain_info, neighbor_local_boundary, edge_direction, neighbor_rank] =
175 send_recv_pairs_.push_back( SendRecvPair{
177 .local_rank =
mpi::rank( domain.comm() ),
178 .local_subdomain = local_subdomain_info,
179 .local_subdomain_boundary =
static_cast< int >( local_edge_boundary ),
180 .local_subdomain_id = local_subdomain_id,
181 .neighbor_rank = neighbor_rank,
182 .neighbor_subdomain = neighbor_subdomain_info,
183 .neighbor_subdomain_boundary =
static_cast< int >( neighbor_local_boundary ),
184 .direction_0 = edge_direction } );
// Face neighbors: each face boundary maps to a single neighbor entry (no inner loop),
// which carries a pair of directions stored in direction_0 / direction_1.
188 for (
const auto& [local_face_boundary, neighbor] : neighborhood.neighborhood_face() )
190 const auto& [neighbor_subdomain_info, neighbor_local_boundary, face_directions, neighbor_rank] =
192 send_recv_pairs_.push_back( SendRecvPair{
194 .local_rank =
mpi::rank( domain.comm() ),
195 .local_subdomain = local_subdomain_info,
196 .local_subdomain_boundary =
static_cast< int >( local_face_boundary ),
197 .local_subdomain_id = local_subdomain_id,
198 .neighbor_rank = neighbor_rank,
199 .neighbor_subdomain = neighbor_subdomain_info,
200 .neighbor_subdomain_boundary =
static_cast< int >( neighbor_local_boundary ),
201 .direction_0 = std::get< 0 >( face_directions ),
202 .direction_1 = std::get< 1 >( face_directions ) } );
// Pairs whose neighbor lives on this rank are handled without MPI, but only when
// local communication is enabled; otherwise local_pairs_ stays empty.
207 local_pairs_.clear();
208 local_pairs_.reserve( send_recv_pairs_.size() );
209 for (
const auto& p : send_recv_pairs_ )
211 if ( enable_local_comm_ && p.local_rank == p.neighbor_rank )
212 local_pairs_.push_back( p );
// Send layout: sort a copy of all pairs by
// (boundary_type, local_subdomain, local boundary, neighbor_subdomain, neighbor boundary)
// so the chunk order inside each rank buffer is deterministic.
217 auto send_pairs = send_recv_pairs_;
218 std::sort( send_pairs.begin(), send_pairs.end(), [](
const SendRecvPair& a,
const SendRecvPair& b ) {
219 if ( a.boundary_type != b.boundary_type ) return a.boundary_type < b.boundary_type;
220 if ( a.local_subdomain != b.local_subdomain ) return a.local_subdomain < b.local_subdomain;
221 if ( a.local_subdomain_boundary != b.local_subdomain_boundary )
222 return a.local_subdomain_boundary < b.local_subdomain_boundary;
223 if ( a.neighbor_subdomain != b.neighbor_subdomain ) return a.neighbor_subdomain < b.neighbor_subdomain;
224 return a.neighbor_subdomain_boundary < b.neighbor_subdomain_boundary;
227 send_chunks_by_rank_.clear();
228 send_total_by_rank_.clear();
// Assign each remote pair a chunk in the destination rank's send buffer.
230 for (
const auto& p : send_pairs )
// Same-rank pairs were already routed to local_pairs_; presumably they are skipped
// here (the statement guarded by this condition is not shown) — confirm.
232 if ( enable_local_comm_ && p.local_rank == p.neighbor_rank )
235 const int sz = piece_num_scalars_( p );
236 auto& chunks = send_chunks_by_rank_[p.neighbor_rank];
// The chunk starts at the running total for this rank (operator[] yields 0 for a rank
// seen for the first time); the total then grows by the chunk size.
238 const int off = send_total_by_rank_[p.neighbor_rank];
239 send_total_by_rank_[p.neighbor_rank] += sz;
241 chunks.push_back( ChunkInfo{ .pair = p, .offset = off, .size = sz } );
// Receive layout: same procedure, but the sort key swaps the local and neighbor
// fields: (boundary_type, neighbor_subdomain, neighbor boundary, local_subdomain,
// local boundary).
// NOTE(review): presumably this makes the receive chunk order equal the order in
// which the sending rank packed its chunks — confirm.
247 auto recv_pairs = send_recv_pairs_;
248 std::sort( recv_pairs.begin(), recv_pairs.end(), [](
const SendRecvPair& a,
const SendRecvPair& b ) {
249 if ( a.boundary_type != b.boundary_type ) return a.boundary_type < b.boundary_type;
250 if ( a.neighbor_subdomain != b.neighbor_subdomain ) return a.neighbor_subdomain < b.neighbor_subdomain;
251 if ( a.neighbor_subdomain_boundary != b.neighbor_subdomain_boundary )
252 return a.neighbor_subdomain_boundary < b.neighbor_subdomain_boundary;
253 if ( a.local_subdomain != b.local_subdomain ) return a.local_subdomain < b.local_subdomain;
254 return a.local_subdomain_boundary < b.local_subdomain_boundary;
257 recv_chunks_by_rank_.clear();
258 recv_total_by_rank_.clear();
// Assign each remote pair a chunk in the source rank's receive buffer.
260 for (
const auto& p : recv_pairs )
// Same guard as in the send loop (guarded statement not shown).
262 if ( enable_local_comm_ && p.local_rank == p.neighbor_rank )
265 const int sz = piece_num_scalars_( p );
266 auto& chunks = recv_chunks_by_rank_[p.neighbor_rank];
268 const int off = recv_total_by_rank_[p.neighbor_rank];
269 recv_total_by_rank_[p.neighbor_rank] += sz;
271 chunks.push_back( ChunkInfo{ .pair = p, .offset = off, .size = sz } );
// Discards the existing per-rank send/receive buffers, rebuilds them from the totals
// computed by the plan, and sizes the MPI request bookkeeping to match.
276 void allocate_rank_buffers_()
278 util::Timer timer(
"ShellBoundaryCommPlan::allocate_rank_buffers" );
280 send_rank_buffers_.clear();
281 recv_rank_buffers_.clear();
// One entry per destination rank with its total send size.
// NOTE(review): the loop bodies are not shown; presumably each creates a buffer of
// `total` scalars for `rank` — confirm.
283 for (
const auto& [rank, total] : send_total_by_rank_ )
// One entry per source rank with its total receive size.
288 for (
const auto& [rank, total] : recv_total_by_rank_ )
// One request slot per buffer, plus a parallel array that records which rank each
// receive request belongs to.
294 data_send_requests_.resize( send_rank_buffers_.size() );
295 data_recv_requests_.resize( recv_rank_buffers_.size() );
296 recv_request_ranks_.resize( recv_rank_buffers_.size() );
// Posts one receive per entry of recv_rank_buffers_, storing the request in
// data_recv_requests_[i] and the source rank in recv_request_ranks_[i].
// Declared const; it writes only to members that are declared `mutable`.
// NOTE(review): the call itself (presumably MPI_Irecv) and the declaration/increment
// of `i` are on lines not shown — confirm.
302 void post_irecvs_()
const
304 util::Timer timer(
"ShellBoundaryCommPlan::post_irecvs" );
307 for (
const auto& [rank, buf] : recv_rank_buffers_ )
// Message length is the buffer's first extent, narrowed to int.
309 const int total_sz =
static_cast< int >( buf.extent( 0 ) );
313 mpi::mpi_datatype< ScalarType >(),
317 &data_recv_requests_[i] );
// Remember the rank for request slot i, so that a completed request index can be
// mapped back to the rank whose buffer must be unpacked.
318 recv_request_ranks_[i] =
rank;
// Handles the same-rank pairs: for every entry of local_pairs_, copies boundary values
// of the neighbor subdomain (identified by its local id) out of `data` into the matching
// vertex / edge / face buffer of `boundary_recv_buffers`.
// Aborts if the neighbor subdomain is not owned by this rank or the boundary type is
// unknown.
324 void local_comm_copy_into_recv_buffers_(
325 const GridDataType& data,
328 util::Timer timer(
"ShellBoundaryCommPlan::local_comm" );
330 const auto& domain = *domain_;
332 for (
const auto& p : local_pairs_ )
// A local pair's neighbor must be present in this rank's subdomain map.
334 if ( !domain.subdomains().contains( p.neighbor_subdomain ) )
335 Kokkos::abort(
"Subdomain not found locally - but it should be there..." );
// Element 0 of the mapped tuple is the neighbor's local subdomain id.
337 const auto neighbor_subdomain_id = std::get< 0 >( domain.subdomains().at( p.neighbor_subdomain ) );
// Vertex boundary.
// NOTE(review): the buffer lookup and most copy arguments are on lines not shown.
339 if ( p.boundary_type == 0 )
344 p.neighbor_subdomain,
347 copy_to_buffer<VecDim>(
350 neighbor_subdomain_id,
// Edge boundary.
353 else if ( p.boundary_type == 1 )
355 auto& recv_buf = boundary_recv_buffers.
buffer_edge(
358 p.neighbor_subdomain,
361 copy_to_buffer<VecDim>(
364 neighbor_subdomain_id,
// Face boundary.
367 else if ( p.boundary_type == 2 )
369 auto& recv_buf = boundary_recv_buffers.
buffer_face(
372 p.neighbor_subdomain,
375 copy_to_buffer<VecDim>(
378 neighbor_subdomain_id,
// Fallback for an unrecognized boundary_type.
383 Kokkos::abort(
"Unknown boundary type" );
// Packs the boundary values of every remote pair into the per-rank send buffers.
// Each chunk is written at its planned offset inside the destination rank's buffer;
// a Kokkos fence at the end ensures the copies have completed before returning.
// Declared const; the send buffers are `mutable` members.
388 void pack_remote_sends_(
const GridDataType& data )
const
390 util::Timer timer(
"ShellBoundaryCommPlan::pack_remote" );
392 const auto& domain = *domain_;
394 for (
const auto& [rank, chunks] : send_chunks_by_rank_ )
// .at() throws if no buffer was allocated for this rank.
396 auto& rank_buf = send_rank_buffers_.at( rank );
398 for (
const auto& ch : chunks )
400 const auto& p = ch.pair;
// Start of this chunk inside the rank buffer; the offset is counted in ScalarType
// elements.
401 ScalarType* base_ptr = rank_buf.data() + ch.offset;
// Vertex boundary: 0-D buffer type.
// NOTE(review): `unmanaged` (used in the edge/face branches) is presumably a BufT
// view wrapping base_ptr, created on lines not shown — confirm.
403 if ( p.boundary_type == 0 )
405 using BufT = grid::Grid0DDataVec< ScalarType, VecDim >;
408 copy_to_buffer<VecDim>(
411 p.local_subdomain_id,
// Edge boundary: 1-D buffer whose length is either the radial or the lateral node
// count (the selecting condition is not shown).
414 else if ( p.boundary_type == 1 )
416 using BufT = grid::Grid1DDataVec< ScalarType, VecDim >;
417 const auto local_edge_boundary =
static_cast< grid::BoundaryEdge >( p.local_subdomain_boundary );
419 domain.domain_info().subdomain_num_nodes_radially() :
420 domain.domain_info().subdomain_num_nodes_per_side_laterally();
423 copy_to_buffer<VecDim>( unmanaged, data, p.local_subdomain_id, local_edge_boundary );
// Face boundary: 2-D buffer; the first extent is lateral, the second is lateral or
// radial (the selecting condition is not shown).
425 else if ( p.boundary_type == 2 )
427 using BufT = grid::Grid2DDataVec< ScalarType, VecDim >;
428 const auto local_face_boundary =
static_cast< grid::BoundaryFace >( p.local_subdomain_boundary );
429 const int ni = domain.domain_info().subdomain_num_nodes_per_side_laterally();
431 domain.domain_info().subdomain_num_nodes_per_side_laterally() :
432 domain.domain_info().subdomain_num_nodes_radially();
435 copy_to_buffer<VecDim>( unmanaged, data, p.local_subdomain_id, local_face_boundary );
// Fallback for an unrecognized boundary_type.
439 Kokkos::abort(
"Unknown boundary type" );
// Wait for all outstanding copy kernels before the buffers are used further.
444 Kokkos::fence(
"pack_rank_send_buffers" );
// Posts one send per entry of send_rank_buffers_, storing the request in
// data_send_requests_[i]. Declared const; the request array is a `mutable` member.
// NOTE(review): the call itself (presumably MPI_Isend) and the handling of `i` are on
// lines not shown — confirm.
447 void post_isends_()
const
449 util::Timer timer(
"ShellBoundaryCommPlan::post_isends" );
452 for (
const auto& [rank, buf] : send_rank_buffers_ )
// Message length is the buffer's first extent, narrowed to int.
454 const int total_sz =
static_cast< int >( buf.extent( 0 ) );
458 mpi::mpi_datatype< ScalarType >(),
462 &data_send_requests_[i] );
// data: grid data that receives the locally exchanged boundary values.
// (Body of unpack_local_: for every same-rank pair, fetch the receive buffer keyed by
// (local subdomain, local boundary, neighbor subdomain, neighbor boundary) and apply it
// to the local subdomain's boundary using `reduction`.)
472 const GridDataType& data,
476 util::Timer timer(
"ShellBoundaryCommPlan::unpack_local" );
478 for (
const auto& p : local_pairs_ )
// Vertex boundary. `local_boundary` is presumably obtained like neighbor_boundary,
// on a line not shown — confirm.
480 if ( p.boundary_type == 0 )
483 const auto neighbor_boundary =
static_cast< grid::BoundaryVertex >( p.neighbor_subdomain_boundary );
484 const auto& recv_buffer = boundary_recv_buffers.
buffer_vertex(
485 p.local_subdomain, local_boundary, p.neighbor_subdomain, neighbor_boundary );
487 recv_buffer, data, p.local_subdomain_id, local_boundary, reduction );
// Edge boundary: the unpack call additionally receives the edge direction.
489 else if ( p.boundary_type == 1 )
491 const auto local_boundary =
static_cast< grid::BoundaryEdge >( p.local_subdomain_boundary );
492 const auto neighbor_boundary =
static_cast< grid::BoundaryEdge >( p.neighbor_subdomain_boundary );
493 const auto& recv_buffer = boundary_recv_buffers.
buffer_edge(
494 p.local_subdomain, local_boundary, p.neighbor_subdomain, neighbor_boundary );
496 recv_buffer, data, p.local_subdomain_id, local_boundary, p.direction_0, reduction );
// Face boundary: the unpack call receives both directions as a tuple.
498 else if ( p.boundary_type == 2 )
500 const auto local_boundary =
static_cast< grid::BoundaryFace >( p.local_subdomain_boundary );
501 const auto neighbor_boundary =
static_cast< grid::BoundaryFace >( p.neighbor_subdomain_boundary );
502 const auto& recv_buffer = boundary_recv_buffers.
buffer_face(
503 p.local_subdomain, local_boundary, p.neighbor_subdomain, neighbor_boundary );
507 p.local_subdomain_id,
509 std::make_tuple( p.direction_0, p.direction_1 ),
// Fallback for an unrecognized boundary_type.
514 Kokkos::abort(
"Unknown boundary type" );
// Unpacks everything received from one remote rank: walks that rank's receive chunks
// and applies each chunk, read from its planned offset in the rank's receive buffer, to
// the local subdomain boundary using `reduction`.
// Both .at() lookups throw if `rank` has no buffer / chunk list.
521 void unpack_remote_rank_(
522 const GridDataType& data,
526 const auto& domain = *domain_;
527 auto& rank_buf = recv_rank_buffers_.at( rank );
528 const auto& chunks = recv_chunks_by_rank_.at( rank );
530 for (
const auto& ch : chunks )
532 const auto& p = ch.pair;
// Start of this chunk inside the rank buffer; the offset is counted in ScalarType
// elements.
533 ScalarType* base_ptr = rank_buf.data() + ch.offset;
// Vertex boundary: 0-D buffer type.
// NOTE(review): `unmanaged` is presumably a BufT view wrapping base_ptr, created on
// lines not shown — confirm.
535 if ( p.boundary_type == 0 )
537 using BufT = grid::Grid0DDataVec< ScalarType, VecDim >;
542 p.local_subdomain_id,
// Edge boundary: 1-D buffer of radial or lateral length (selecting condition not
// shown); the edge direction is forwarded to the unpack call.
546 else if ( p.boundary_type == 1 )
548 using BufT = grid::Grid1DDataVec< ScalarType, VecDim >;
549 const auto local_edge =
static_cast< grid::BoundaryEdge >( p.local_subdomain_boundary );
551 domain.domain_info().subdomain_num_nodes_radially() :
552 domain.domain_info().subdomain_num_nodes_per_side_laterally();
555 unmanaged, data, p.local_subdomain_id, local_edge, p.direction_0, reduction );
// Face boundary: 2-D buffer; both directions are forwarded as a tuple.
557 else if ( p.boundary_type == 2 )
559 using BufT = grid::Grid2DDataVec< ScalarType, VecDim >;
560 const auto local_face =
static_cast< grid::BoundaryFace >( p.local_subdomain_boundary );
561 const int ni = domain.domain_info().subdomain_num_nodes_per_side_laterally();
563 domain.domain_info().subdomain_num_nodes_per_side_laterally() :
564 domain.domain_info().subdomain_num_nodes_radially();
569 p.local_subdomain_id,
571 std::make_tuple( p.direction_0, p.direction_1 ),
// Fallback for an unrecognized boundary_type.
576 Kokkos::abort(
"Unknown boundary type" );
// Completes the remote exchange: waits for the posted receives one at a time and
// unpacks each rank's buffer as soon as its request completes, then waits for all
// outstanding sends.
590 void wait_and_unpack_remote_(
591 const GridDataType& data,
594 util::Timer timer(
"ShellBoundaryCommPlan::waitall" );
// Exactly recv_req_count_ completions are consumed, one per iteration.
596 for (
int completed = 0; completed < recv_req_count_; ++completed )
598 int idx = MPI_UNDEFINED;
// The two branches issue the same MPI_Waitany call; they differ only in the timer
// label, so the wait for the first message is reported separately from the rest.
599 if ( completed == 0 )
601 util::Timer t(
"ShellBoundaryCommPlan::mpi_waitany_first" );
602 MPI_Waitany( recv_req_count_, data_recv_requests_.data(), &idx, MPI_STATUS_IGNORE );
606 util::Timer t(
"ShellBoundaryCommPlan::mpi_waitany_rest" );
607 MPI_Waitany( recv_req_count_, data_recv_requests_.data(), &idx, MPI_STATUS_IGNORE );
// Map the completed request slot back to its source rank and unpack that rank's data.
// NOTE(review): idx is used without a check. MPI_Waitany sets it to MPI_UNDEFINED when
// the list holds no active request, and the return code is ignored — confirm that
// recv_req_count_ always equals the number of receives actually posted.
609 unpack_remote_rank_( data, recv_request_ranks_[idx], reduction );
// The send buffers must not be reused until the sends have completed.
612 if ( send_req_count_ > 0 )
614 util::Timer t(
"ShellBoundaryCommPlan::mpi_waitall_sends" );
615 MPI_Waitall( send_req_count_, data_send_requests_.data(), MPI_STATUSES_IGNORE );
// Domain this plan was built for; set from the address of the constructor argument.
620 const grid::shell::DistributedDomain* domain_ =
nullptr;
// When true, pairs whose neighbor is on the same rank are routed through local_pairs_
// and kept out of the per-rank MPI chunks.
621 bool enable_local_comm_ =
true;
// Every (local boundary, neighbor) pair found while building the plan.
624 std::vector< SendRecvPair > send_recv_pairs_;
// Subset of send_recv_pairs_ with local_rank == neighbor_rank (empty when local
// communication is disabled).
627 std::vector< SendRecvPair > local_pairs_;
// Per remote rank: the ordered chunks (pair, offset, size) inside that rank's send /
// receive buffer.
630 std::map< mpi::MPIRank, std::vector< ChunkInfo > > send_chunks_by_rank_;
631 std::map< mpi::MPIRank, std::vector< ChunkInfo > > recv_chunks_by_rank_;
// Per remote rank: the sum of all chunk sizes, i.e. the required buffer length.
632 std::map< mpi::MPIRank, int > send_total_by_rank_;
633 std::map< mpi::MPIRank, int > recv_total_by_rank_;
// Per remote rank: the contiguous buffer handed to MPI. `mutable` because the const
// exchange routines write into them.
637 mutable std::map< mpi::MPIRank, rank_buffer_view > send_rank_buffers_;
638 mutable std::map< mpi::MPIRank, rank_buffer_view > recv_rank_buffers_;
// MPI request handles, one slot per rank buffer, plus the rank that each receive
// request slot belongs to.
641 mutable std::vector< MPI_Request > data_send_requests_;
642 mutable std::vector< MPI_Request > data_recv_requests_;
643 mutable std::vector< mpi::MPIRank > recv_request_ranks_;
// Request counts passed to MPI_Waitall / MPI_Waitany.
// NOTE(review): where these are updated is not shown; presumably when the requests are
// posted — confirm.
644 mutable int send_req_count_ = 0;
645 mutable int recv_req_count_ = 0;