Skip to content

Commit

Permalink
Fixing handling of bool value type for collective operations
Browse files Browse the repository at this point in the history
  • Loading branch information
hkaiser committed May 14, 2024
1 parent ba8f05f commit 368ea21
Show file tree
Hide file tree
Showing 5 changed files with 175 additions and 8 deletions.
7 changes: 5 additions & 2 deletions libs/full/collectives/include/hpx/collectives/all_reduce.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,11 @@ namespace hpx::traits {
{
// compute reduction result only once
auto it = data.begin();
data[0] = hpx::reduce(
++it, data.end(), data[0], HPX_FORWARD(F, op));
data[0] = Communicator::template handle_bool<
std::decay_t<T>>(hpx::reduce(++it, data.end(),
Communicator::template handle_bool<std::decay_t<T>>(
data[0]),
HPX_FORWARD(F, op)));
data_available = true;
}
return Communicator::template handle_bool<std::decay_t<T>>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,9 @@ namespace hpx::traits {

// first value is not taken into account
auto it = data.begin();
hpx::exclusive_scan(it, data.end(), dest.begin(), *it,
hpx::exclusive_scan(it, data.end(), dest.begin(),
Communicator::template handle_bool<std::decay_t<T>>(
*it),
HPX_FORWARD(F, op));

std::swap(data, dest);
Expand Down
13 changes: 8 additions & 5 deletions libs/full/collectives/include/hpx/collectives/reduce.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -257,19 +257,22 @@ namespace hpx::traits {
communication::reduce_tag>::name(),
which, generation,
// step function (invoked once for get)
[&t](auto& data, std::size_t which) {
data[which] = HPX_FORWARD(T, t);
[&t](auto& data, std::size_t site) {
data[site] = HPX_FORWARD(T, t);
},
// finalizer (invoked once after all data has been received)
[op = HPX_FORWARD(F, op)](
auto& data, bool&, std::size_t) mutable {
HPX_ASSERT(!data.empty());

if (data.size() > 1)
{
auto it = data.begin();
return Communicator::template handle_bool<
std::decay_t<T>>(hpx::reduce(++it, data.end(),
HPX_MOVE(data[0]), HPX_FORWARD(F, op)));
Communicator::template handle_bool<std::decay_t<T>>(
HPX_MOVE(data[0])),
HPX_FORWARD(F, op)));
}
return Communicator::template handle_bool<std::decay_t<T>>(
HPX_MOVE(data[0]));
Expand All @@ -285,8 +288,8 @@ namespace hpx::traits {
communication::reduce_tag>::name(),
which, generation,
// step function (invoked for each set)
[t = HPX_FORWARD(T, t)](auto& data, std::size_t which) mutable {
data[which] = HPX_FORWARD(T, t);
[t = HPX_FORWARD(T, t)](auto& data, std::size_t site) mutable {
data[site] = HPX_FORWARD(T, t);
},
// no finalizer
nullptr);
Expand Down
2 changes: 2 additions & 0 deletions libs/full/collectives/tests/regressions/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ if(HPX_WITH_NETWORKING)
broadcast_wait_for_2822
collectives_bool_5940
multiple_gather_ops_2001
reduce_vector_bool
trivially_copyable_all_gather
)

Expand All @@ -22,6 +23,7 @@ if(HPX_WITH_NETWORKING)
set(barrier_3792_PARAMETERS LOCALITIES 3 THREADS_PER_LOCALITY 1)
set(collectives_bool_5940_PARAMETERS LOCALITIES 2)
set(multiple_gather_ops_2001_PARAMETERS LOCALITIES 2)
set(reduce_vector_bool_2001_PARAMETERS LOCALITIES 2)

foreach(test ${tests})
set(sources ${test}.cpp)
Expand Down
157 changes: 157 additions & 0 deletions libs/full/collectives/tests/regressions/reduce_vector_bool.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
// Copyright (c) 2019-2024 Hartmut Kaiser
//
// SPDX-License-Identifier: BSL-1.0
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)

#include <hpx/config.hpp>

#if !defined(HPX_COMPUTE_DEVICE_CODE)
#include <hpx/hpx.hpp>
#include <hpx/hpx_init.hpp>
#include <hpx/modules/collectives.hpp>
#include <hpx/modules/testing.hpp>

#include <cstdint>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

using namespace hpx::collectives;

constexpr char const* reduce_direct_basename = "/test/reduce_direct/";
#if defined(HPX_DEBUG)
constexpr int ITERATIONS = 100;
#else
constexpr int ITERATIONS = 1000;
#endif

void test_multiple_use_with_generation()
{
std::uint32_t const this_locality = hpx::get_locality_id();
std::uint32_t const num_localities =
hpx::get_num_localities(hpx::launch::sync);
HPX_TEST_LTE(static_cast<std::uint32_t>(2), num_localities);

auto const reduce_direct_client =
create_communicator(reduce_direct_basename,
num_sites_arg(num_localities), this_site_arg(this_locality));

hpx::chrono::high_resolution_timer const t;

for (int i = 0; i != ITERATIONS; ++i)
{
bool value = ((this_locality + i) % 2) ? true : false;
if (this_locality == 0)
{
hpx::future<bool> overall_result =
reduce_here(reduce_direct_client, std::move(value),
std::logical_or<>{}, generation_arg(i + 1));

bool sum = false;
for (std::uint32_t j = 0; j != num_localities; ++j)
{
sum = sum || (((j + i) % 2) ? true : false);
}
HPX_TEST_EQ(sum, overall_result.get());
}
else
{
hpx::future<void> overall_result = reduce_there(
reduce_direct_client, std::move(value), generation_arg(i + 1));
overall_result.get();
}
}

auto const elapsed = t.elapsed();
if (this_locality == 0)
{
std::cout << "remote timing: " << elapsed / ITERATIONS << "[s]\n";
}
}

void test_local_use()
{
constexpr std::uint32_t num_sites = 10;

std::vector<hpx::future<void>> sites;
sites.reserve(num_sites);

// launch num_sites threads to represent different sites
for (std::uint32_t site = 0; site != num_sites; ++site)
{
sites.push_back(hpx::async([=]() {
auto const reduce_direct_client =
create_communicator(reduce_direct_basename,
num_sites_arg(num_sites), this_site_arg(site));

hpx::chrono::high_resolution_timer const t;

// test functionality based on immediate local result value
for (int i = 0; i != ITERATIONS; ++i)
{
bool value = ((site + i) % 2) ? true : false;
if (site == 0)
{
hpx::future<bool> overall_result = reduce_here(
reduce_direct_client, std::move(value), std::logical_or<>{},
generation_arg(i + 1), this_site_arg(site));

bool sum = false;
for (std::uint32_t j = 0; j != num_sites; ++j)
{
sum = sum || (((j + i) % 2) ? true : false);
}
HPX_TEST_EQ(sum, overall_result.get());
}
else
{
hpx::future<void> overall_result =
reduce_there(reduce_direct_client, std::move(value),
generation_arg(i + 1), this_site_arg(site));
overall_result.get();
}
}

auto const elapsed = t.elapsed();
if (site == 0)
{
std::cout << "local timing: " << elapsed / (10 * ITERATIONS)
<< "[s]\n";
}
}));
}

hpx::wait_all(std::move(sites));
}

int hpx_main()
{
#if defined(HPX_HAVE_NETWORKING)
if (hpx::get_num_localities(hpx::launch::sync) > 1)
{
test_multiple_use_with_generation();
}
#endif

if (hpx::get_locality_id() == 0)
{
test_local_use();
}

return hpx::finalize();
}

int main(int argc, char* argv[])
{
std::vector<std::string> const cfg = {"hpx.run_hpx_main!=1"};

hpx::init_params init_args;
init_args.cfg = cfg;

HPX_TEST_EQ(hpx::init(argc, argv, init_args), 0);
return hpx::util::report_errors();
}

#endif

0 comments on commit 368ea21

Please sign in to comment.