From 463aecd43ce3296656728b2d3a769ee9882b21c0 Mon Sep 17 00:00:00 2001 From: Michael Schellenberger Costa Date: Tue, 30 Apr 2024 19:10:23 +0200 Subject: [PATCH] Avoid copying output iterators in `thrust::copy_if` --- thrust/thrust/system/cuda/detail/copy_if.h | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/thrust/thrust/system/cuda/detail/copy_if.h b/thrust/thrust/system/cuda/detail/copy_if.h index b36285b2f2a..7706a6ac2a4 100644 --- a/thrust/thrust/system/cuda/detail/copy_if.h +++ b/thrust/thrust/system/cuda/detail/copy_if.h @@ -45,6 +45,7 @@ # include # include +# include # include # include # include @@ -95,10 +96,9 @@ struct DispatchCopyIf size_t& temp_storage_bytes, InputIt first, StencilIt stencil, - OutputIt output, + OutputIt& output, Predicate predicate, - OffsetT num_items, - OutputIt& output_end) + OffsetT num_items) { using num_selected_out_it_t = OffsetT*; using equality_op_t = cub::NullType; @@ -147,7 +147,6 @@ struct DispatchCopyIf // Return for empty problems if (num_items == 0) { - output_end = output; return status; } @@ -180,8 +179,7 @@ struct DispatchCopyIf status = cuda_cub::synchronize(policy); CUDA_CUB_RET_IF_FAIL(status); OffsetT num_selected = get_value(policy, d_num_selected_out); - - output_end = output + num_selected; + thrust::advance(output, num_selected); return status; } }; @@ -197,8 +195,7 @@ THRUST_RUNTIME_FUNCTION OutputIt copy_if( { using size_type = typename iterator_traits::difference_type; - size_type num_items = static_cast(thrust::distance(first, last)); - OutputIt output_end{}; + size_type num_items = static_cast(thrust::distance(first, last)); cudaError_t status = cudaSuccess; size_t temp_storage_bytes = 0; @@ -214,7 +211,7 @@ THRUST_RUNTIME_FUNCTION OutputIt copy_if( dispatch32_t::dispatch, dispatch64_t::dispatch, num_items, - (policy, nullptr, temp_storage_bytes, first, stencil, output, predicate, num_items_fixed, output_end)); + (policy, nullptr, temp_storage_bytes, first, stencil, output, predicate, num_items_fixed)); cuda_cub::throw_on_error(status, "copy_if failed on 1st step"); // Allocate temporary storage. @@ -227,10 +224,10 @@ THRUST_RUNTIME_FUNCTION OutputIt copy_if( dispatch32_t::dispatch, dispatch64_t::dispatch, num_items, - (policy, temp_storage, temp_storage_bytes, first, stencil, output, predicate, num_items_fixed, output_end)); + (policy, temp_storage, temp_storage_bytes, first, stencil, output, predicate, num_items_fixed)); cuda_cub::throw_on_error(status, "copy_if failed on 2nd step"); - return output_end; + return output; } } // namespace detail