static constexpr int VecDim = grid::grid_data_vec_dim< GridDataType >();
using buffer_view = Kokkos::View< ScalarType*, memory_space >;
: domain_fine_( &domain_fine )
, domain_coarse_( &domain_coarse )
, union_comm_( domain_fine.comm() )
{
    build_plan_( subdomain_to_rank_fine, subdomain_to_rank_coarse );
}
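// Forward redistribution: move data laid out on the fine domain into the coarse domain's layout.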
void apply( const GridDataType& src_fine, GridDataType& dst_coarse )
{
    if ( union_comm_ == MPI_COMM_NULL ) return;
    if ( identity_plan_ ) return;

    run_alltoallv_( src_fine, dst_coarse, false );
}
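// Reverse redistribution: move data laid out on the coarse domain back into the fine domain's layout.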
if ( union_comm_ == MPI_COMM_NULL ) return;
if ( identity_plan_ ) return;

run_alltoallv_( src_coarse, dst_fine, true );
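// Number of scalar entries one subdomain contributes to the exchange:
// ni x ni x nr grid points times VecDim vector components.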
int subdomain_block_size_() const
{
    const int ni = Ni_( *domain_fine_ );
    const int nr = Nr_( *domain_fine_ );
    return ni * ni * nr * VecDim;
}
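// Local storage indices of the subdomain on the fine and coarse domains;
// -1 marks the side on which this rank does not own the subdomain.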
int local_subdomain_on_fine;
int local_subdomain_on_coarse;
if ( union_comm_ == MPI_COMM_NULL ) return;

const auto& dom_info = domain_fine_->domain_info();
const int n_rad = dom_info.num_subdomains_in_radial_direction();
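// Translate ranks of the coarse domain's communicator into ranks of the union
// communicator, so coarse-side owners can be addressed in the Alltoallv plan.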
MPI_Group union_group = MPI_GROUP_NULL;
MPI_Group coarse_group = MPI_GROUP_NULL;
MPI_Comm_group( union_comm_, &union_group );

const MPI_Comm coarse_comm = domain_coarse_->comm();
const bool have_coarse = ( coarse_comm != MPI_COMM_NULL );

int coarse_size = 0;
if ( have_coarse )
{
    MPI_Comm_group( coarse_comm, &coarse_group );
    MPI_Group_size( coarse_group, &coarse_size );
}
MPI_Allreduce( MPI_IN_PLACE, &coarse_size, 1, MPI_INT, MPI_MAX, union_comm_ );
std::vector< int > coarse_to_union( coarse_size, MPI_PROC_NULL );

if ( have_coarse )
{
    std::vector< int > coarse_ranks( coarse_size );
    std::iota( coarse_ranks.begin(), coarse_ranks.end(), 0 );
    MPI_Group_translate_ranks( coarse_group, coarse_size, coarse_ranks.data(),
                               union_group, coarse_to_union.data() );
}
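// Ranks outside the coarse communicator cannot translate ranks themselves, so the
// mapping is broadcast from the union rank that holds coarse rank 0.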
int bcaster_union_rank = std::numeric_limits< int >::max();
if ( have_coarse )
{
    int my_coarse_rank = mpi::rank( coarse_comm );
    if ( my_coarse_rank == 0 )
    {
        bcaster_union_rank = mpi::rank( union_comm_ );
    }
}
MPI_Allreduce( MPI_IN_PLACE, &bcaster_union_rank, 1, MPI_INT, MPI_MIN, union_comm_ );
MPI_Bcast( coarse_to_union.data(), coarse_size, MPI_INT, bcaster_union_rank, union_comm_ );
coarse_to_union_ = std::move( coarse_to_union );
if ( union_group != MPI_GROUP_NULL ) MPI_Group_free( &union_group );
if ( coarse_group != MPI_GROUP_NULL ) MPI_Group_free( &coarse_group );
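// Map each globally identified subdomain to its local storage index on this rank's
// fine and coarse domains.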
std::map< grid::shell::SubdomainInfo, int > fine_local_idx;
for ( const auto& [sdr, info] : domain_fine_->subdomains() )
{
    fine_local_idx[sdr] = std::get< 0 >( info );
}

std::map< grid::shell::SubdomainInfo, int > coarse_local_idx;
for ( const auto& [sdr, info] : domain_coarse_->subdomains() )
{
    coarse_local_idx[sdr] = std::get< 0 >( info );
}
const int my_union_rank = mpi::rank( union_comm_ );
const int block = subdomain_block_size_();
int union_size = 0;
MPI_Comm_size( union_comm_, &union_size );
send_counts_.assign( union_size, 0 );
recv_counts_.assign( union_size, 0 );

send_messages_.clear();
recv_messages_.clear();
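// Enumerate every subdomain of the shell. Where this rank owns the subdomain on the
// fine side it records a send message; where it owns it on the coarse side it records
// a receive message. Per-peer element counts are accumulated for MPI_Alltoallv.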
for ( int diamond_id = 0; diamond_id < 10; ++diamond_id )
{
    for ( int x = 0; x < n_diam; ++x )
    {
        for ( int y = 0; y < n_diam; ++y )
        {
            for ( int r = 0; r < n_rad; ++r )
            {
                grid::shell::SubdomainInfo s( diamond_id, x, y, r );

                const int fine_owner = fn_fine( s, n_diam, n_rad );
                const int coarse_owner = fn_coarse( s, n_diam, n_rad );
                if ( coarse_owner < 0 || coarse_owner >= coarse_size )
                {
                    throw std::runtime_error(
                        "Redistribute: coarse subdomain_to_rank produced out-of-range rank" );
                }

                const int recv_union = coarse_to_union_[coarse_owner];

                const bool i_send = ( fine_owner == my_union_rank );
                const bool i_recv = ( recv_union == my_union_rank );

                if ( i_send )
                {
                    Message m;
                    m.send_rank = my_union_rank;
                    m.recv_rank = recv_union;
                    m.local_subdomain_on_fine = fine_local_idx.at( s );
                    m.local_subdomain_on_coarse = -1;
                    send_messages_.push_back( m );
                    send_counts_[recv_union] += block;
                }

                if ( i_recv )
                {
                    Message m;
                    m.send_rank = fine_owner;
                    m.recv_rank = my_union_rank;
                    m.local_subdomain_on_fine = -1;
                    m.local_subdomain_on_coarse = coarse_local_idx.at( s );
                    recv_messages_.push_back( m );
                    recv_counts_[fine_owner] += block;
                }
            }
        }
    }
}
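// Sender and receiver must pack and unpack a peer's subdomains in the same order;
// a stable sort by peer rank keeps the subdomain order consistent on both sides.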
auto by_peer_and_subdomain = []( const Message& a, const Message& b ) {
    if ( a.send_rank != b.send_rank ) return a.send_rank < b.send_rank;
    if ( a.recv_rank != b.recv_rank ) return a.recv_rank < b.recv_rank;

    const int ai = a.local_subdomain_on_fine >= 0 ? a.local_subdomain_on_fine : a.local_subdomain_on_coarse;
    const int bi = b.local_subdomain_on_fine >= 0 ? b.local_subdomain_on_fine : b.local_subdomain_on_coarse;
    return ai < bi;
};
std::stable_sort( send_messages_.begin(), send_messages_.end(),
                  []( const Message& a, const Message& b ) { return a.recv_rank < b.recv_rank; } );
std::stable_sort( recv_messages_.begin(), recv_messages_.end(),
                  []( const Message& a, const Message& b ) { return a.send_rank < b.send_rank; } );
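// Exclusive prefix sums of the per-rank counts yield the Alltoallv displacements.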
send_displs_.assign( union_size, 0 );
recv_displs_.assign( union_size, 0 );
for ( int r = 1; r < union_size; ++r )
{
    send_displs_[r] = send_displs_[r - 1] + send_counts_[r - 1];
    recv_displs_[r] = recv_displs_[r - 1] + recv_counts_[r - 1];
}

const int total_send = send_displs_.back() + send_counts_.back();
const int total_recv = recv_displs_.back() + recv_counts_.back();
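// Allocate the device pack/unpack buffers and their host mirrors once per plan;
// an extent of at least one avoids zero-length views.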
const int send_alloc = std::max( total_send, 1 );
const int recv_alloc = std::max( total_recv, 1 );
send_buf_ = buffer_view( Kokkos::view_alloc( Kokkos::WithoutInitializing, "redistribute_send" ), send_alloc );
recv_buf_ = buffer_view( Kokkos::view_alloc( Kokkos::WithoutInitializing, "redistribute_recv" ), recv_alloc );

send_host_ = host_buffer_view( Kokkos::view_alloc( Kokkos::WithoutInitializing, "redistribute_send_host" ),
                               send_alloc );
recv_host_ = host_buffer_view( Kokkos::view_alloc( Kokkos::WithoutInitializing, "redistribute_recv_host" ),
                               recv_alloc );
int local_identity = 1;
for ( const auto& m : send_messages_ )
    if ( m.send_rank != m.recv_rank ) { local_identity = 0; break; }
if ( local_identity )
    for ( const auto& m : recv_messages_ )
        if ( m.send_rank != m.recv_rank ) { local_identity = 0; break; }
MPI_Allreduce( MPI_IN_PLACE, &local_identity, 1, MPI_INT, MPI_LAND, union_comm_ );
identity_plan_ = ( local_identity != 0 );
}
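// Pack the subdomains listed in `messages` from `src` into the flat device buffer.
// Each message occupies one contiguous block of ni*ni*nr*VecDim scalars, in message
// order, matching the counts and displacements computed in build_plan_.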
void pack_( const GridDataType& src, const buffer_view& buf, const std::vector< Message >& messages,
            bool use_fine_index ) const
{
    const int ni = Ni_( *domain_fine_ );
    const int nr = Nr_( *domain_fine_ );
    const int block = subdomain_block_size_();

    Kokkos::View< int*, memory_space > msg_sdr( Kokkos::view_alloc( Kokkos::WithoutInitializing, "msg_sdr" ),
                                                messages.size() );
    auto msg_sdr_host = Kokkos::create_mirror_view( msg_sdr );
    for ( size_t i = 0; i < messages.size(); ++i )
    {
        msg_sdr_host( i ) =
            use_fine_index ? messages[i].local_subdomain_on_fine : messages[i].local_subdomain_on_coarse;
    }
    Kokkos::deep_copy( msg_sdr, msg_sdr_host );
    const int num_messages = static_cast< int >( messages.size() );
    if ( num_messages == 0 ) return;
    if constexpr ( VecDim == 1 )
    {
        Kokkos::parallel_for(
            "redistribute_pack",
            Kokkos::MDRangePolicy< Kokkos::Rank< 5 > >( { 0, 0, 0, 0, 0 }, { num_messages, ni, ni, nr, VecDim } ),
            KOKKOS_LAMBDA( int m, int i, int j, int k, int c ) {
                const int local_sdr = msg_sdr( m );
                const int flat = m * block + ( ( i * ni + j ) * nr + k ) * VecDim + c;
                buf( flat ) = src( local_sdr, i, j, k );
            } );
    }
    else
    {
        Kokkos::parallel_for(
            "redistribute_pack",
            Kokkos::MDRangePolicy< Kokkos::Rank< 5 > >( { 0, 0, 0, 0, 0 }, { num_messages, ni, ni, nr, VecDim } ),
            KOKKOS_LAMBDA( int m, int i, int j, int k, int c ) {
                const int local_sdr = msg_sdr( m );
                const int flat = m * block + ( ( i * ni + j ) * nr + k ) * VecDim + c;
                buf( flat ) = src( local_sdr, i, j, k, c );
            } );
    }
    Kokkos::fence( "redistribute_pack" );
}
void unpack_( GridDataType& dst, const buffer_view& buf, const std::vector< Message >& messages,
              bool use_fine_index ) const
{
    const int ni = Ni_( *domain_fine_ );
    const int nr = Nr_( *domain_fine_ );
    const int block = subdomain_block_size_();

    Kokkos::View< int*, memory_space > msg_sdr( Kokkos::view_alloc( Kokkos::WithoutInitializing, "msg_sdr" ),
                                                messages.size() );
    auto msg_sdr_host = Kokkos::create_mirror_view( msg_sdr );
    for ( size_t i = 0; i < messages.size(); ++i )
    {
        msg_sdr_host( i ) =
            use_fine_index ? messages[i].local_subdomain_on_fine : messages[i].local_subdomain_on_coarse;
    }
    Kokkos::deep_copy( msg_sdr, msg_sdr_host );
    const int num_messages = static_cast< int >( messages.size() );
    if ( num_messages == 0 ) return;
    if constexpr ( VecDim == 1 )
    {
        Kokkos::parallel_for(
            "redistribute_unpack",
            Kokkos::MDRangePolicy< Kokkos::Rank< 5 > >( { 0, 0, 0, 0, 0 }, { num_messages, ni, ni, nr, VecDim } ),
            KOKKOS_LAMBDA( int m, int i, int j, int k, int c ) {
                const int local_sdr = msg_sdr( m );
                const int flat = m * block + ( ( i * ni + j ) * nr + k ) * VecDim + c;
                dst( local_sdr, i, j, k ) = buf( flat );
            } );
    }
    else
    {
        Kokkos::parallel_for(
            "redistribute_unpack",
            Kokkos::MDRangePolicy< Kokkos::Rank< 5 > >( { 0, 0, 0, 0, 0 }, { num_messages, ni, ni, nr, VecDim } ),
            KOKKOS_LAMBDA( int m, int i, int j, int k, int c ) {
                const int local_sdr = msg_sdr( m );
                const int flat = m * block + ( ( i * ni + j ) * nr + k ) * VecDim + c;
                dst( local_sdr, i, j, k, c ) = buf( flat );
            } );
    }
    Kokkos::fence( "redistribute_unpack" );
}
void run_alltoallv_( const GridDataType& src, GridDataType& dst, bool transpose )
{
    const std::vector< Message >& pack_msgs = transpose ? recv_messages_ : send_messages_;
    const std::vector< Message >& unpack_msgs = transpose ? send_messages_ : recv_messages_;
    const std::vector< int >& s_counts = transpose ? recv_counts_ : send_counts_;
    const std::vector< int >& s_displs = transpose ? recv_displs_ : send_displs_;
    const std::vector< int >& r_counts = transpose ? send_counts_ : recv_counts_;
    const std::vector< int >& r_displs = transpose ? send_displs_ : recv_displs_;

    const bool pack_uses_fine_idx = !transpose;
    const bool unpack_uses_fine_idx = transpose;

    buffer_view& device_send = transpose ? recv_buf_ : send_buf_;
    buffer_view& device_recv = transpose ? send_buf_ : recv_buf_;
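    // Pack on the device, stage through the host mirrors for MPI_Alltoallv, then unpack
    // on the device.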
    host_buffer_view& host_send = transpose ? recv_host_ : send_host_;
    host_buffer_view& host_recv = transpose ? send_host_ : recv_host_;

    pack_( src, device_send, pack_msgs, pack_uses_fine_idx );
    Kokkos::deep_copy( host_send, device_send );

    MPI_Alltoallv( host_send.data(), s_counts.data(), s_displs.data(), mpi::mpi_datatype< ScalarType >(),
                   host_recv.data(), r_counts.data(), r_displs.data(), mpi::mpi_datatype< ScalarType >(),
                   union_comm_ );

    Kokkos::deep_copy( device_recv, host_recv );
    unpack_( dst, device_recv, unpack_msgs, unpack_uses_fine_idx );
}
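// Plan state, built once in build_plan_ and reused by every subsequent exchange.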
MPI_Comm union_comm_ = MPI_COMM_NULL;

std::vector< int > coarse_to_union_;

std::vector< Message > send_messages_;
std::vector< Message > recv_messages_;

std::vector< int > send_counts_;
std::vector< int > send_displs_;
std::vector< int > recv_counts_;
std::vector< int > recv_displs_;
bool identity_plan_ = false;