Skip to content

Commit

Permalink
Merge pull request #6600 from STEllAR-GROUP/sync_collectives
Browse files Browse the repository at this point in the history
Fixing sync collectives
  • Loading branch information
hkaiser authored Jan 11, 2025
2 parents 09ffd2f + 1899c68 commit a04cd6f
Show file tree
Hide file tree
Showing 16 changed files with 330 additions and 48 deletions.
1 change: 1 addition & 0 deletions .github/workflows/tests.examples.targets
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,4 @@ tests.examples.quickstart.partitioned_vector_spmd_foreach
tests.examples.quickstart.sort_by_key_demo
tests.examples.transpose.transpose_block_numa
tests.examples.modules.collectives.distributed.tcp.channel_communicator
tests.examples.modules.collectives.distributed.tcp.distributed_pi
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2007-2022 Hartmut Kaiser
// Copyright (c) 2007-2025 Hartmut Kaiser
// Copyright (c) 2011 Bryce Lelbach
//
// SPDX-License-Identifier: BSL-1.0
Expand All @@ -13,14 +13,15 @@

#include <hpx/config.hpp>

#include <cstdint>
#include <string>

namespace hpx {

///////////////////////////////////////////////////////////////////////////
/// A HPX runtime can be executed in two different modes: console mode
/// and worker mode.
enum class runtime_mode
enum class runtime_mode : std::int8_t
{
invalid = -1,
console = 0, ///< The runtime is the console locality
Expand All @@ -30,7 +31,7 @@ namespace hpx {
local = 3, ///< The runtime is fully local
default_ = 4, ///< The runtime mode will be determined
///< based on the command line arguments
last
last = default_
};

/// Get the readable string representing the name of the given runtime_mode
Expand Down
6 changes: 3 additions & 3 deletions libs/core/runtime_configuration/src/runtime_mode.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
////////////////////////////////////////////////////////////////////////////////
// Copyright (c) 2012 Bryce Adelstein-Lelbach
// Copyright (c) 2012-2023 Hartmut Kaiser
// Copyright (c) 2012-2025 Hartmut Kaiser
//
// SPDX-License-Identifier: BSL-1.0
// Distributed under the Boost Software License, Version 1.0. (See accompanying
Expand Down Expand Up @@ -28,15 +28,15 @@ namespace hpx {

char const* get_runtime_mode_name(runtime_mode state) noexcept
{
if (state < runtime_mode::invalid || state >= runtime_mode::last)
if (state < runtime_mode::invalid || state > runtime_mode::last)
return "invalid (value out of bounds)";
return strings::runtime_mode_names[static_cast<int>(state) + 1];
}

runtime_mode get_runtime_mode_from_name(std::string const& mode)
{
for (std::size_t i = 0;
static_cast<runtime_mode>(i) < runtime_mode::last; ++i)
static_cast<runtime_mode>(i) <= runtime_mode::last; ++i)
{
if (mode == strings::runtime_mode_names[i])
return static_cast<runtime_mode>(i - 1);
Expand Down
5 changes: 4 additions & 1 deletion libs/full/collectives/examples/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,14 @@ else()
return()
endif()

set(example_programs channel_communicator)
set(example_programs channel_communicator distributed_pi)

set(channel_communicator_PARAMETERS LOCALITIES 2 THREADS_PER_LOCALITY 2)
set(channel_communicator_FLAGS DEPENDENCIES iostreams_component)

set(distributed_pi_PARAMETERS LOCALITIES 2 THREADS_PER_LOCALITY 2)
set(distributed_pi_FLAGS COMPILE_FLAGS -DHPX_HAVE_RUN_MAIN_EVERYWHERE)

foreach(example_program ${example_programs})

set(sources ${example_program}.cpp)
Expand Down
58 changes: 58 additions & 0 deletions libs/full/collectives/examples/distributed_pi.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// Copyright (c) 2025 Hartmut Kaiser
//
// SPDX-License-Identifier: BSL-1.0
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)

#include <hpx/hpx_main.hpp>

#if !defined(HPX_COMPUTE_DEVICE_CODE)
#include <hpx/hpx.hpp>

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <string>

inline double sqr(double val)
{
return val * val;
}

int main(int argc, char* argv[])
{
std::size_t N = 1'000'000;
std::uint32_t num_localities = hpx::get_num_localities(hpx::launch::sync);
std::uint32_t locality_id = hpx::get_locality_id();

if (locality_id == 0 && argc > 1)
N = std::stol(argv[1]);

hpx::collectives::broadcast(hpx::collectives::get_world_communicator(), N);

std::size_t const blocksize = N / num_localities;
std::size_t const begin = blocksize * locality_id;
std::size_t const end = blocksize * (locality_id + 1);
double h = 1.0 / N;

double pi = 0.0;
for (std::size_t i = begin; i != end; ++i)
pi += h * 4.0 / (1 + sqr(i * h));

hpx::collectives::reduce(
hpx::collectives::get_world_communicator(), pi, std::plus{});

if (locality_id == 0)
std::cout << "pi: " << pi << std::endl;

return 0;
}

#else

int main(int argc, char* argv[])
{
return 0;
}

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ namespace hpx::collectives {

fid.wait(); // make sure communicator was created

if (this_site == fid.get_info().second)
if (this_site == std::get<2>(fid.get_info_ex()))
{
broadcast_to(
hpx::launch::sync, HPX_MOVE(fid), value, this_site, generation);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2021 Hartmut Kaiser
// Copyright (c) 2021-2025 Hartmut Kaiser
//
// SPDX-License-Identifier: BSL-1.0
// Distributed under the Boost Software License, Version 1.0. (See accompanying
Expand Down Expand Up @@ -102,8 +102,6 @@ namespace hpx { namespace collectives {

#else

#include <hpx/config.hpp>

#if !defined(HPX_COMPUTE_DEVICE_CODE)
#include <hpx/async_base/launch_policy.hpp>
#include <hpx/async_distributed/async.hpp>
Expand All @@ -115,9 +113,8 @@ namespace hpx { namespace collectives {
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

namespace hpx { namespace collectives {
namespace hpx::collectives {

// forward declarations
class channel_communicator;
Expand All @@ -126,10 +123,18 @@ namespace hpx { namespace collectives {
hpx::future<T> get(
channel_communicator, that_site_arg, tag_arg = tag_arg());

template <typename T>
T get(hpx::launch::sync_policy, channel_communicator, that_site_arg,
tag_arg = tag_arg());

template <typename T>
hpx::future<void> set(
channel_communicator, that_site_arg, T&&, tag_arg = tag_arg());

template <typename T>
void set(hpx::launch::sync_policy, channel_communicator, that_site_arg, T&&,
tag_arg = tag_arg());

class channel_communicator
{
private:
Expand All @@ -140,10 +145,18 @@ namespace hpx { namespace collectives {
template <typename T>
friend hpx::future<T> get(channel_communicator, that_site_arg, tag_arg);

template <typename T>
friend T get(hpx::launch::sync_policy, channel_communicator,
that_site_arg, tag_arg);

template <typename T>
friend hpx::future<void> set(
channel_communicator, that_site_arg, T&&, tag_arg);

template <typename T>
friend void set(hpx::launch::sync_policy, channel_communicator,
that_site_arg, T&&, tag_arg);

private:
HPX_EXPORT channel_communicator(char const* basename,
num_sites_arg num_sites, this_site_arg this_site,
Expand All @@ -163,6 +176,11 @@ namespace hpx { namespace collectives {

HPX_EXPORT void free();

explicit operator bool() const noexcept
{
return comm_.get() != nullptr;
}

private:
std::shared_ptr<detail::channel_communicator> comm_;
};
Expand All @@ -185,14 +203,41 @@ namespace hpx { namespace collectives {
return comm.comm_->template get<T>(site.argument_, tag.argument_);
}

template <typename T>
T get(hpx::launch::sync_policy, channel_communicator comm,
that_site_arg site, tag_arg tag)
{
return comm.comm_->template get<T>(site.argument_, tag.argument_).get();
}

///////////////////////////////////////////////////////////////////////////
template <typename T>
hpx::future<void> set(
channel_communicator comm, that_site_arg site, T&& value, tag_arg tag)
{
return comm.comm_->set(
site.argument_, HPX_FORWARD(T, value), tag.argument_);
}
}} // namespace hpx::collectives

template <typename T>
void set(hpx::launch::sync_policy, channel_communicator comm,
that_site_arg site, T&& value, tag_arg tag)
{
return comm.comm_
->set(site.argument_, HPX_FORWARD(T, value), tag.argument_)
.get();
}

///////////////////////////////////////////////////////////////////////////
// Predefined p2p communicator (refers to all localities)
HPX_EXPORT channel_communicator get_world_channel_communicator();

namespace detail {

HPX_EXPORT void create_world_channel_communicator();
HPX_EXPORT void reset_world_channel_communicator();
} // namespace detail
} // namespace hpx::collectives

#endif // !HPX_COMPUTE_DEVICE_CODE
#endif // DOXYGEN
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2020-2023 Hartmut Kaiser
// Copyright (c) 2020-2025 Hartmut Kaiser
//
// SPDX-License-Identifier: BSL-1.0
// Distributed under the Boost Software License, Version 1.0. (See accompanying
Expand Down Expand Up @@ -105,14 +105,13 @@ namespace hpx { namespace collectives {

#else

#include <hpx/config.hpp>

#if !defined(HPX_COMPUTE_DEVICE_CODE)
#include <hpx/collectives/argument_types.hpp>
#include <hpx/collectives/detail/communicator.hpp>
#include <hpx/components/client_base.hpp>
#include <hpx/type_support/extra_data.hpp>

#include <tuple>
#include <utility>

///////////////////////////////////////////////////////////////////////////////
Expand All @@ -123,6 +122,7 @@ namespace hpx::collectives::detail {
{
num_sites_arg num_sites_;
this_site_arg this_site_;
root_site_arg root_site_;
};
} // namespace hpx::collectives::detail

Expand Down Expand Up @@ -173,8 +173,13 @@ namespace hpx::collectives {
{
}

HPX_EXPORT void set_info(
num_sites_arg num_sites, this_site_arg this_site) noexcept;
HPX_EXPORT void set_info(num_sites_arg num_sites,
this_site_arg this_site,
root_site_arg root_site = root_site_arg()) noexcept;

[[nodiscard]] HPX_EXPORT
std::tuple<num_sites_arg, this_site_arg, root_site_arg>
get_info_ex() const noexcept;

[[nodiscard]] HPX_EXPORT std::pair<num_sites_arg, this_site_arg>
get_info() const noexcept;
Expand All @@ -186,9 +191,26 @@ namespace hpx::collectives {
};

///////////////////////////////////////////////////////////////////////////
// Predefined global communicator
// Predefined global communicator (refers to all localities)
HPX_EXPORT communicator get_world_communicator();

namespace detail {

HPX_EXPORT void create_global_communicator();
HPX_EXPORT void reset_global_communicator();
} // namespace detail

///////////////////////////////////////////////////////////////////////////
// Predefined local communicator (refers to all threads on the calling
// locality)
HPX_EXPORT communicator get_local_communicator();

namespace detail {

HPX_EXPORT void create_local_communicator();
HPX_EXPORT void reset_local_communicator();
} // namespace detail

///////////////////////////////////////////////////////////////////////////
HPX_EXPORT communicator create_communicator(char const* basename,
num_sites_arg num_sites = num_sites_arg(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,14 @@
#include <hpx/lcos_local/channel.hpp>
#include <hpx/lock_registration/detail/register_locks.hpp>
#include <hpx/synchronization/spinlock.hpp>
#include <hpx/type_support/unused.hpp>

#include <cstddef>
#include <map>
#include <mutex>
#include <utility>
#include <vector>

namespace hpx { namespace collectives { namespace detail {
namespace hpx::collectives::detail {

///////////////////////////////////////////////////////////////////////////
class channel_communicator_server
Expand All @@ -39,7 +38,6 @@ namespace hpx { namespace collectives { namespace detail {

public:
channel_communicator_server() //-V730
: data_()
{
HPX_ASSERT(false); // shouldn't ever be called
}
Expand All @@ -57,8 +55,7 @@ namespace hpx { namespace collectives { namespace detail {

{
std::unique_lock l(data_[which].mtx_);
util::ignore_while_checking il(&l);
HPX_UNUSED(il);
[[maybe_unused]] util::ignore_while_checking il(&l);

channel_type& c = data_[which].channels_[tag];
f = c.get();
Expand All @@ -84,8 +81,7 @@ namespace hpx { namespace collectives { namespace detail {
void set(std::size_t which, T value, std::size_t tag)
{
std::unique_lock l(data_[which].mtx_);
util::ignore_while_checking il(&l);
HPX_UNUSED(il);
[[maybe_unused]] util::ignore_while_checking il(&l);

data_[which].channels_[tag].set(unique_any_nonser(HPX_MOVE(value)));
}
Expand Down Expand Up @@ -157,6 +153,6 @@ namespace hpx { namespace collectives { namespace detail {
std::size_t this_site_;
std::vector<client_type> clients_;
};
}}} // namespace hpx::collectives::detail
} // namespace hpx::collectives::detail

#endif // COMPUTE_HOST_CODE
Loading

0 comments on commit a04cd6f

Please sign in to comment.