diff --git a/rocprim/include/rocprim/device/device_find_first_of.hpp b/rocprim/include/rocprim/device/device_find_first_of.hpp index df5c69ed2..203547004 100644 --- a/rocprim/include/rocprim/device/device_find_first_of.hpp +++ b/rocprim/include/rocprim/device/device_find_first_of.hpp @@ -58,125 +58,128 @@ namespace detail } \ while(0) -template -ROCPRIM_KERNEL -void init_find_first_of_kernel(T* output, T size, ordered_block_id ordered_bid) +template +struct find_first_of_impl_kernels { - *output = size; - ordered_bid.reset(); -} + template + static ROCPRIM_KERNEL + void init_find_first_of_kernel(T* output, T size, ordered_block_id ordered_bid) + { + *output = size; + ordered_bid.reset(); + } -template -ROCPRIM_KERNEL + static ROCPRIM_KERNEL #ifndef DOXYGEN_DOCUMENTATION_BUILD -__launch_bounds__(device_params().kernel_config.block_size) + __launch_bounds__(device_params().kernel_config.block_size) #endif -void find_first_of_kernel(InputIterator1 input, - InputIterator2 keys, - size_t* output, - size_t size, - size_t keys_size, - ordered_block_id ordered_bid, - BinaryFunction compare_function) -{ - constexpr find_first_of_config_params params = device_params(); + void find_first_of_kernel(InputIterator1 input, + InputIterator2 keys, + size_t* output, + size_t size, + size_t keys_size, + ordered_block_id ordered_bid, + BinaryFunction compare_function) + { + constexpr find_first_of_config_params params = device_params(); - constexpr unsigned int block_size = params.kernel_config.block_size; - constexpr unsigned int items_per_thread = params.kernel_config.items_per_thread; - constexpr unsigned int items_per_block = block_size * items_per_thread; - constexpr unsigned int identity = std::numeric_limits::max(); + constexpr unsigned int block_size = params.kernel_config.block_size; + constexpr unsigned int items_per_thread = params.kernel_config.items_per_thread; + constexpr unsigned int items_per_block = block_size * items_per_thread; + constexpr unsigned int identity = std::numeric_limits::max(); - using type = typename std::iterator_traits::value_type; - using key_type = typename std::iterator_traits::value_type; + using type = typename std::iterator_traits::value_type; + using key_type = typename std::iterator_traits::value_type; - const unsigned int thread_id = ::rocprim::detail::block_thread_id<0>(); + const unsigned int thread_id = ::rocprim::detail::block_thread_id<0>(); - ROCPRIM_SHARED_MEMORY struct - { - unsigned int block_first_index; - size_t global_first_index; + ROCPRIM_SHARED_MEMORY struct + { + unsigned int block_first_index; + size_t global_first_index; - typename decltype(ordered_bid)::storage_type ordered_bid; - } storage; + typename decltype(ordered_bid)::storage_type ordered_bid; + } storage; - if(thread_id == 0) - { - storage.block_first_index = identity; - } - syncthreads(); - - while(true) - { if(thread_id == 0) { - storage.global_first_index = atomic_load(output); + storage.block_first_index = identity; } - const size_t block_id = ordered_bid.get(thread_id, storage.ordered_bid); - const size_t block_offset = block_id * items_per_block; - // ordered_bid.get() calls syncthreads(), it is safe to read global_first_index + syncthreads(); - // Exit if all input has been processed or one of previous blocks has found a match - if(block_offset >= storage.global_first_index) + while(true) { - break; - } + if(thread_id == 0) + { + storage.global_first_index = atomic_load(output); + } + const size_t block_id = ordered_bid.get(thread_id, storage.ordered_bid); + const size_t block_offset = block_id * items_per_block; + // ordered_bid.get() calls syncthreads(), it is safe to read global_first_index - unsigned int thread_first_index = identity; + // Exit if all input has been processed or one of previous blocks has found a match + if(block_offset >= storage.global_first_index) + { + break; + } - if(block_offset + items_per_block <= size) - { - type items[items_per_thread]; - block_load_direct_striped(thread_id, input + block_offset, items); - for(size_t key_index = 0; key_index < keys_size; ++key_index) + unsigned int thread_first_index = identity; + + if(block_offset + items_per_block <= size) { - const key_type key = keys[key_index]; - ROCPRIM_UNROLL - for(unsigned int i = 0; i < items_per_thread; ++i) + type items[items_per_thread]; + block_load_direct_striped(thread_id, input + block_offset, items); + for(size_t key_index = 0; key_index < keys_size; ++key_index) { - if(compare_function(key, items[i])) - { - thread_first_index = min(thread_first_index, i); - } + const key_type key = keys[key_index]; + ROCPRIM_UNROLL + for(unsigned int i = 0; i < items_per_thread; ++i) + { + if(compare_function(key, items[i])) + { + thread_first_index = min(thread_first_index, i); + } + } } } - } - else - { - const unsigned int valid = size - block_offset; - - type items[items_per_thread]; - block_load_direct_striped(thread_id, input + block_offset, items, valid); - for(size_t key_index = 0; key_index < keys_size; ++key_index) + else { - const key_type key = keys[key_index]; - ROCPRIM_UNROLL - for(unsigned int i = 0; i < items_per_thread; ++i) + const unsigned int valid = size - block_offset; + + type items[items_per_thread]; + block_load_direct_striped(thread_id, input + block_offset, items, valid); + for(size_t key_index = 0; key_index < keys_size; ++key_index) { - if(i * block_size + thread_id < valid && compare_function(key, items[i])) - { - thread_first_index = min(thread_first_index, i); - } + const key_type key = keys[key_index]; + ROCPRIM_UNROLL + for(unsigned int i = 0; i < items_per_thread; ++i) + { + if(i * block_size + thread_id < valid && compare_function(key, items[i])) + { + thread_first_index = min(thread_first_index, i); + } + } } } - } - if(thread_first_index != identity) - { - // This happens to some blocks rarely so it is not beneficial to avoid atomic conflicts - // with block_reduce which needs to be computed even if no threads have a match. - atomic_min(&storage.block_first_index, thread_first_index * block_size + thread_id); - } - syncthreads(); - if(storage.block_first_index != identity) - { - if(thread_id == 0) + if(thread_first_index != identity) { - atomic_min(output, block_offset + storage.block_first_index); + // This happens to some blocks rarely so it is not beneficial to avoid atomic conflicts + // with block_reduce which needs to be computed even if no threads have a match. + atomic_min(&storage.block_first_index, thread_first_index * block_size + thread_id); + } + syncthreads(); + if(storage.block_first_index != identity) + { + if(thread_id == 0) + { + atomic_min(output, block_offset + storage.block_first_index); + } + break; } - break; } } -} +}; template::value_type; using config = wrapped_find_first_of_config; + using find_first_of_kernels + = find_first_of_impl_kernels; target_arch target_arch; hipError_t result = host_target_arch(stream, target_arch); @@ -238,12 +243,14 @@ hipError_t find_first_of_impl(void* temporary_storage, { start = std::chrono::steady_clock::now(); } - init_find_first_of_kernel<<<1, 1, 0, stream>>>(tmp_output, size, ordered_bid); + find_first_of_kernels::init_find_first_of_kernel<<<1, 1, 0, stream>>>(tmp_output, + size, + ordered_bid); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("init_find_first_of_kernel", 1, start); if(size > 0 && keys_size > 0) { - auto kernel = find_first_of_kernel; + auto kernel = find_first_of_kernels::find_first_of_kernel; const size_t shared_memory_size = 0;