Skip to content

Commit

Permalink
Device reduce optimization (#637)
Browse files Browse the repository at this point in the history
* Reduce the number of items per thread when the size is larger than necessary

* Fit blocks that are just too small

* Clean up device_reduce code

* Added some comments

* Fix merging develop
  • Loading branch information
NB4444 authored Nov 11, 2024
1 parent c9c9a12 commit f37b344
Show file tree
Hide file tree
Showing 4 changed files with 309 additions and 163 deletions.
15 changes: 13 additions & 2 deletions rocprim/include/rocprim/device/config_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ enum class target_arch : unsigned int
gfx906 = 906,
gfx908 = 908,
gfx90a = 910,
gfx942 = 942,
gfx1030 = 1030,
gfx1100 = 1100,
gfx1102 = 1102,
Expand Down Expand Up @@ -203,14 +204,22 @@ constexpr bool prefix_equals(const char* lhs, const char* rhs, std::size_t n)

constexpr target_arch get_target_arch_from_name(const char* const arch_name, const std::size_t n)
{
constexpr const char* target_names[]
= {"gfx803", "gfx900", "gfx906", "gfx908", "gfx90a", "gfx1030", "gfx1100", "gfx1102"};
constexpr const char* target_names[] = {"gfx803",
"gfx900",
"gfx906",
"gfx908",
"gfx90a",
"gfx942",
"gfx1030",
"gfx1100",
"gfx1102"};
constexpr target_arch target_architectures[] = {
target_arch::gfx803,
target_arch::gfx900,
target_arch::gfx906,
target_arch::gfx908,
target_arch::gfx90a,
target_arch::gfx942,
target_arch::gfx1030,
target_arch::gfx1100,
target_arch::gfx1102,
Expand Down Expand Up @@ -266,6 +275,8 @@ auto dispatch_target_arch(const target_arch target_arch)
return Config::template architecture_config<target_arch::gfx908>::params;
case target_arch::gfx90a:
return Config::template architecture_config<target_arch::gfx90a>::params;
case target_arch::gfx942:
return Config::template architecture_config<target_arch::gfx942>::params;
case target_arch::gfx1030:
return Config::template architecture_config<target_arch::gfx1030>::params;
case target_arch::gfx1100:
Expand Down
68 changes: 68 additions & 0 deletions rocprim/include/rocprim/device/detail/config/device_reduce.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,74 @@ struct default_reduce_config<static_cast<unsigned int>(target_arch::gfx90a),
: reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce>
{};

// gfx942 default, based on key_type = double.
// Selected for floating-point keys wider than 4 bytes and at most 8 bytes.
template<class Key>
struct default_reduce_config<
    static_cast<unsigned int>(target_arch::gfx942),
    Key,
    std::enable_if_t<(bool(rocprim::is_floating_point<Key>::value)
                      && (sizeof(Key) > 4)
                      && (sizeof(Key) <= 8))>>
    : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce>
{};

// gfx942 default, based on key_type = float.
// Selected for floating-point keys wider than 2 bytes and at most 4 bytes.
template<class Key>
struct default_reduce_config<
    static_cast<unsigned int>(target_arch::gfx942),
    Key,
    std::enable_if_t<(bool(rocprim::is_floating_point<Key>::value)
                      && (sizeof(Key) > 2)
                      && (sizeof(Key) <= 4))>>
    : reduce_config<256, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce>
{};

// gfx942 default, based on key_type = rocprim::half.
// Selected for floating-point keys of at most 2 bytes (no lower size bound).
template<class Key>
struct default_reduce_config<
    static_cast<unsigned int>(target_arch::gfx942),
    Key,
    std::enable_if_t<(bool(rocprim::is_floating_point<Key>::value)
                      && (sizeof(Key) <= 2))>>
    : reduce_config<128, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce>
{};

// gfx942 default, based on key_type = int64_t.
// Selected for non-floating-point keys wider than 4 bytes and at most 8 bytes.
template<class Key>
struct default_reduce_config<
    static_cast<unsigned int>(target_arch::gfx942),
    Key,
    std::enable_if_t<(!bool(rocprim::is_floating_point<Key>::value)
                      && (sizeof(Key) > 4)
                      && (sizeof(Key) <= 8))>>
    : reduce_config<256, 2, ::rocprim::block_reduce_algorithm::using_warp_reduce>
{};

// gfx942 default, based on key_type = int.
// Selected for non-floating-point keys wider than 2 bytes and at most 4 bytes.
template<class Key>
struct default_reduce_config<
    static_cast<unsigned int>(target_arch::gfx942),
    Key,
    std::enable_if_t<(!bool(rocprim::is_floating_point<Key>::value)
                      && (sizeof(Key) > 2)
                      && (sizeof(Key) <= 4))>>
    : reduce_config<256, 4, ::rocprim::block_reduce_algorithm::using_warp_reduce>
{};

// gfx942 default, based on key_type = short.
// Selected for non-floating-point keys wider than 1 byte and at most 2 bytes.
template<class Key>
struct default_reduce_config<
    static_cast<unsigned int>(target_arch::gfx942),
    Key,
    std::enable_if_t<(!bool(rocprim::is_floating_point<Key>::value)
                      && (sizeof(Key) > 1)
                      && (sizeof(Key) <= 2))>>
    : reduce_config<128, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce>
{};

// gfx942 default, based on key_type = int8_t.
// Selected for non-floating-point keys of at most 1 byte.
template<class Key>
struct default_reduce_config<
    static_cast<unsigned int>(target_arch::gfx942),
    Key,
    std::enable_if_t<(!bool(rocprim::is_floating_point<Key>::value)
                      && (sizeof(Key) <= 1))>>
    : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce>
{};

} // end namespace detail

END_ROCPRIM_NAMESPACE
Expand Down
86 changes: 34 additions & 52 deletions rocprim/include/rocprim/device/detail/device_reduce.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
Expand All @@ -21,61 +21,51 @@
#ifndef ROCPRIM_DEVICE_DETAIL_DEVICE_REDUCE_HPP_
#define ROCPRIM_DEVICE_DETAIL_DEVICE_REDUCE_HPP_

#include <type_traits>
#include <iterator>

#include "../../config.hpp"
#include "../../detail/temp_storage.hpp"
#include "../../detail/various.hpp"
#include "../config_types.hpp"
#include "../device_reduce_config.hpp"

#include "../../intrinsics.hpp"
#include "../../functional.hpp"
#include "../../intrinsics.hpp"
#include "../../types.hpp"

#include "../../block/block_load.hpp"
#include "../../block/block_reduce.hpp"

#include <iterator>
#include <type_traits>

BEGIN_ROCPRIM_NAMESPACE

namespace detail
{

// Helper functions for combining the final reduced value with the
// initial value.
//
// This overload is enabled only when WithInitialValue is true. It returns
// reduce_op(initial_value, output), i.e. the initial value is passed as the
// first (left) operand of the binary operation.
template<
bool WithInitialValue,
class T,
class BinaryFunction
>
template<bool WithInitialValue, class T, class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
auto reduce_with_initial(T output,
T initial_value,
BinaryFunction reduce_op)
-> typename std::enable_if<WithInitialValue, T>::type
auto reduce_with_initial(T output, T initial_value, BinaryFunction reduce_op) ->
typename std::enable_if<WithInitialValue, T>::type
{
return reduce_op(initial_value, output);
}

// Overload enabled only when WithInitialValue is false: the initial value and
// the reduction operator are deliberately ignored and output is returned
// unchanged.
template<
bool WithInitialValue,
class T,
class BinaryFunction
>
template<bool WithInitialValue, class T, class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE
auto reduce_with_initial(T output,
T initial_value,
BinaryFunction reduce_op)
-> typename std::enable_if<!WithInitialValue, T>::type
auto reduce_with_initial(T output, T initial_value, BinaryFunction reduce_op) ->
typename std::enable_if<!WithInitialValue, T>::type
{
// The casts to void only silence unused-parameter warnings.
(void) initial_value;
(void) reduce_op;
(void)initial_value;
(void)reduce_op;
return output;
}

template<
bool WithInitialValue,
bool FitLarger,
unsigned int FitItems,
class Config,
class ResultType,
class InputIterator,
Expand All @@ -92,8 +82,10 @@ void block_reduce_kernel_impl(InputIterator input,
{
static constexpr reduce_config_params params = device_params<Config>();

constexpr unsigned int block_size = params.reduce_config.block_size;
constexpr unsigned int items_per_thread = params.reduce_config.items_per_thread;
constexpr unsigned int block_size = params.reduce_config.block_size;
constexpr unsigned int items_per_thread
= FitLarger ? params.reduce_config.items_per_thread * FitItems
: ceiling_div(params.reduce_config.items_per_thread, FitItems);

using result_type = ResultType;

Expand All @@ -111,12 +103,10 @@ void block_reduce_kernel_impl(InputIterator input,
// last incomplete block
if(flat_block_id == (input_size / items_per_block))
{
block_load_direct_striped<block_size>(
flat_id,
input + block_offset,
values,
valid_in_last_block
);
block_load_direct_striped<block_size>(flat_id,
input + block_offset,
values,
valid_in_last_block);

output_value = values[0];
ROCPRIM_UNROLL
Expand All @@ -136,34 +126,26 @@ void block_reduce_kernel_impl(InputIterator input,
}
else
{
block_load_direct_striped<block_size>(
flat_id,
input + block_offset,
values
);
block_load_direct_striped<block_size>(flat_id, input + block_offset, values);

// load input values into values
block_reduce_type()
.reduce(
values, // input
output_value, // output
reduce_op
);
block_reduce_type().reduce(values, // input
output_value, // output
reduce_op);
}

// Save value into output
if(flat_id == 0)
{
output[flat_block_id] = input_size == 0
? static_cast<result_type>(initial_value)
: reduce_with_initial<WithInitialValue>(
output_value,
static_cast<result_type>(initial_value),
reduce_op
);
output[flat_block_id]
= input_size == 0
? static_cast<result_type>(initial_value)
: reduce_with_initial<WithInitialValue>(output_value,
static_cast<result_type>(initial_value),
reduce_op);
}
}
} // end of detail namespace
} // namespace detail

END_ROCPRIM_NAMESPACE

Expand Down
Loading

0 comments on commit f37b344

Please sign in to comment.