From 2c79f4e52d270ff7f2be56114742d625dc3aaf11 Mon Sep 17 00:00:00 2001 From: Hartmut Kaiser Date: Tue, 7 Jan 2025 13:43:37 -0600 Subject: [PATCH] Fixing sync collectives - adding example --- libs/full/collectives/examples/CMakeLists.txt | 5 +- .../collectives/examples/distributed_pi.cpp | 47 ++++++++++ .../include/hpx/collectives/broadcast.hpp | 2 +- .../hpx/collectives/create_communicator.hpp | 18 +++- .../include/hpx/collectives/reduce.hpp | 2 +- .../collectives/src/create_communicator.cpp | 88 +++++++++++++++---- libs/full/include/include/hpx/hpx.hpp | 3 +- 7 files changed, 143 insertions(+), 22 deletions(-) create mode 100644 libs/full/collectives/examples/distributed_pi.cpp diff --git a/libs/full/collectives/examples/CMakeLists.txt b/libs/full/collectives/examples/CMakeLists.txt index 636f692d527f..1bc2db31e44a 100644 --- a/libs/full/collectives/examples/CMakeLists.txt +++ b/libs/full/collectives/examples/CMakeLists.txt @@ -17,11 +17,14 @@ else() return() endif() -set(example_programs channel_communicator) +set(example_programs channel_communicator distributed_pi) set(channel_communicator_PARAMETERS LOCALITIES 2 THREADS_PER_LOCALITY 2) set(channel_communicator_FLAGS DEPENDENCIES iostreams_component) +set(distributed_pi_PARAMETERS LOCALITIES 2 THREADS_PER_LOCALITY 2) +set(distributed_pi_FLAGS COMPILE_FLAGS -DHPX_HAVE_RUN_MAIN_EVERYWHERE) + foreach(example_program ${example_programs}) set(sources ${example_program}.cpp) diff --git a/libs/full/collectives/examples/distributed_pi.cpp b/libs/full/collectives/examples/distributed_pi.cpp new file mode 100644 index 000000000000..6a26703b2cca --- /dev/null +++ b/libs/full/collectives/examples/distributed_pi.cpp @@ -0,0 +1,47 @@ +// Copyright (c) 2025 Hartmut Kaiser +// +// SPDX-License-Identifier: BSL-1.0 +// Distributed under the Boost Software License, Version 1.0. (See accompanying +// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) + +#include +#include + +#include +#include +#include +#include + +inline double sqr(double val) +{ + return val * val; +} + +int main(int argc, char* argv[]) +{ + std::size_t N = 1'000'000; //'000; + std::uint32_t num_localities = hpx::get_num_localities(hpx::launch::sync); + std::uint32_t locality_id = hpx::get_locality_id(); + + if (locality_id == 0 && argc > 1) + N = std::stol(argv[1]); + + hpx::collectives::broadcast(hpx::collectives::get_world_communicator(), N); + + std::size_t const blocksize = N / num_localities; + std::size_t const begin = blocksize * locality_id; + std::size_t const end = blocksize * (locality_id + 1); + double h = 1.0 / N; + + double pi = 0.0; + for (std::size_t i = begin; i != end; ++i) + pi += h * 4.0 / (1 + sqr(i * h)); + + hpx::collectives::reduce( + hpx::collectives::get_world_communicator(), pi, std::plus{}); + + if (locality_id == 0) + std::cout << "pi: " << pi << std::endl; + + return 0; +} diff --git a/libs/full/collectives/include/hpx/collectives/broadcast.hpp b/libs/full/collectives/include/hpx/collectives/broadcast.hpp index 388a7b9af1ef..3f1c0189b1f8 100644 --- a/libs/full/collectives/include/hpx/collectives/broadcast.hpp +++ b/libs/full/collectives/include/hpx/collectives/broadcast.hpp @@ -467,7 +467,7 @@ namespace hpx::collectives { fid.wait(); // make sure communicator was created - if (this_site == fid.get_info().second) + if (this_site == std::get<2>(fid.get_info_ex())) { broadcast_to( hpx::launch::sync, HPX_MOVE(fid), value, this_site, generation); diff --git a/libs/full/collectives/include/hpx/collectives/create_communicator.hpp b/libs/full/collectives/include/hpx/collectives/create_communicator.hpp index 7aa5837354db..e34d3a048479 100644 --- a/libs/full/collectives/include/hpx/collectives/create_communicator.hpp +++ b/libs/full/collectives/include/hpx/collectives/create_communicator.hpp @@ -113,6 +113,7 @@ namespace hpx { namespace collectives { #include #include +#include #include /////////////////////////////////////////////////////////////////////////////// @@ -123,6 +124,7 @@ namespace hpx::collectives::detail { { num_sites_arg num_sites_; this_site_arg this_site_; + root_site_arg root_site_; }; } // namespace hpx::collectives::detail @@ -173,8 +175,13 @@ namespace hpx::collectives { { } - HPX_EXPORT void set_info( - num_sites_arg num_sites, this_site_arg this_site) noexcept; + HPX_EXPORT void set_info(num_sites_arg num_sites, + this_site_arg this_site, + root_site_arg root_site = root_site_arg()) noexcept; + + [[nodiscard]] HPX_EXPORT + std::tuple + get_info_ex() const noexcept; [[nodiscard]] HPX_EXPORT std::pair get_info() const noexcept; @@ -186,9 +193,14 @@ namespace hpx::collectives { }; /////////////////////////////////////////////////////////////////////////// - // Predefined global communicator + // Predefined global communicator (refers to all localities) HPX_EXPORT communicator get_world_communicator(); + /////////////////////////////////////////////////////////////////////////// + // Predefined local communicator (refers to all threads on the calling + // locality) + HPX_EXPORT communicator get_local_communicator(); + /////////////////////////////////////////////////////////////////////////// HPX_EXPORT communicator create_communicator(char const* basename, num_sites_arg num_sites = num_sites_arg(), diff --git a/libs/full/collectives/include/hpx/collectives/reduce.hpp b/libs/full/collectives/include/hpx/collectives/reduce.hpp index edc5519ff660..a6893d6aa61a 100644 --- a/libs/full/collectives/include/hpx/collectives/reduce.hpp +++ b/libs/full/collectives/include/hpx/collectives/reduce.hpp @@ -524,7 +524,7 @@ namespace hpx::collectives { fid.wait(); // make sure communicator was created - if (this_site == fid.get_info().second) + if (this_site == std::get<2>(fid.get_info_ex())) { local_result = reduce_here(hpx::launch::sync, HPX_MOVE(fid), HPX_FORWARD(T, local_result), HPX_FORWARD(F, op), this_site, diff --git a/libs/full/collectives/src/create_communicator.cpp b/libs/full/collectives/src/create_communicator.cpp index 95874e424357..d49214ef0ef6 100644 --- a/libs/full/collectives/src/create_communicator.cpp +++ b/libs/full/collectives/src/create_communicator.cpp @@ -68,29 +68,47 @@ namespace hpx::collectives { } // namespace detail /////////////////////////////////////////////////////////////////////////// - void communicator::set_info( - num_sites_arg num_sites, this_site_arg this_site) noexcept + void communicator::set_info(num_sites_arg num_sites, + this_site_arg this_site, root_site_arg root_site) noexcept { - auto& [num_sites_, this_site_] = + auto& [num_sites_, this_site_, root_site_] = get_extra_data(); num_sites_ = num_sites; this_site_ = this_site; + root_site_ = root_site; } - std::pair communicator::get_info() - const noexcept + std::pair + communicator::get_info() const noexcept { auto const* client_data = try_get_extra_data(); if (client_data != nullptr) { - return std::make_pair( - client_data->num_sites_, client_data->this_site_); + return std::make_tuple(client_data->num_sites_, + client_data->this_site_); } - return std::make_pair(num_sites_arg{}, this_site_arg{}); + return std::make_tuple( + num_sites_arg{}, this_site_arg{}); + } + + std::tuple + communicator::get_info_ex() const noexcept + { + auto const* client_data = + try_get_extra_data(); + + if (client_data != nullptr) + { + return std::make_tuple(client_data->num_sites_, + client_data->this_site_, client_data->root_site_); + } + + return std::make_tuple( + num_sites_arg{}, this_site_arg{}, root_site_arg()); } /////////////////////////////////////////////////////////////////////////// @@ -141,13 +159,17 @@ namespace hpx::collectives { "operation was already registered: {}", target.registered_name()); } - target.set_info(num_sites, this_site); + target.set_info(num_sites, this_site, root_site); return target; }); } // find existing communicator - return hpx::find_from_basename(HPX_MOVE(name), root_site); + return hpx::find_from_basename(HPX_MOVE(name), root_site) + .then(hpx::launch::sync, [=](communicator&& c) { + c.set_info(num_sites, this_site, root_site); + return c; + }); } /////////////////////////////////////////////////////////////////////////// @@ -193,31 +215,67 @@ namespace hpx::collectives { c.registered_name()); } - c.set_info(num_sites, this_site); + c.set_info(num_sites, this_site, root_site); return c; } // find existing communicator - return hpx::find_from_basename(HPX_MOVE(name), root_site); + return hpx::find_from_basename(HPX_MOVE(name), root_site) + .then(hpx::launch::sync, [=](communicator&& c) { + c.set_info(num_sites, this_site, root_site); + return c; + }); } /////////////////////////////////////////////////////////////////////////// // Predefined global communicator namespace { + communicator world_communicator; - hpx::mutex world_communicator_mtx; + communicator local_communicator; + hpx::mutex communicator_mtx; } // namespace communicator get_world_communicator() { { - std::lock_guard l(world_communicator_mtx); + std::lock_guard l(communicator_mtx); if (!world_communicator) + { + auto const num_sites = + num_sites_arg(agas::get_num_localities(hpx::launch::sync)); + auto const this_site = this_site_arg(agas::get_locality_id()); + world_communicator = - create_communicator("hpx::collectives::world_communicator"); + create_communicator("/0/world_communicator", num_sites, + this_site, generation_arg(), root_site_arg(0)); + world_communicator.set_info( + num_sites, this_site, root_site_arg(0)); + } } return world_communicator; } + + communicator get_local_communicator() + { + { + std::lock_guard l(communicator_mtx); + if (!local_communicator) + { + auto const num_sites = + num_sites_arg(hpx::get_num_worker_threads()); + auto const this_site = + this_site_arg(hpx::get_worker_thread_num()); + + local_communicator = + create_local_communicator("/local_communicator", num_sites, + this_site, generation_arg(), root_site_arg(0)); + local_communicator.set_info( + num_sites, this_site, root_site_arg(0)); + } + } + return local_communicator; + } } // namespace hpx::collectives #endif // !HPX_COMPUTE_DEVICE_CODE diff --git a/libs/full/include/include/hpx/hpx.hpp b/libs/full/include/include/hpx/hpx.hpp index 1aa815b444b3..18b5c2045cd3 100644 --- a/libs/full/include/include/hpx/hpx.hpp +++ b/libs/full/include/include/hpx/hpx.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2007-2023 Hartmut Kaiser +// Copyright (c) 2007-2025 Hartmut Kaiser // // SPDX-License-Identifier: BSL-1.0 // Distributed under the Boost Software License, Version 1.0. (See accompanying @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include