Skip to content

Commit

Permalink
Avoid copying output iterators in thrust::copy_if
Browse files Browse the repository at this point in the history
  • Loading branch information
miscco committed May 2, 2024
1 parent ab4cf50 commit 463aecd
Showing 1 changed file with 8 additions and 11 deletions.
19 changes: 8 additions & 11 deletions thrust/thrust/system/cuda/detail/copy_if.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
# include <cub/util_temporary_storage.cuh>
# include <cub/util_type.cuh>

# include <thrust/advance.h>
# include <thrust/detail/alignment.h>
# include <thrust/detail/cstdint.h>
# include <thrust/detail/function.h>
Expand Down Expand Up @@ -95,10 +96,9 @@ struct DispatchCopyIf
size_t& temp_storage_bytes,
InputIt first,
StencilIt stencil,
OutputIt output,
OutputIt& output,
Predicate predicate,
OffsetT num_items,
OutputIt& output_end)
OffsetT num_items)
{
using num_selected_out_it_t = OffsetT*;
using equality_op_t = cub::NullType;
Expand Down Expand Up @@ -147,7 +147,6 @@ struct DispatchCopyIf
// Return for empty problems
if (num_items == 0)
{
output_end = output;
return status;
}

Expand Down Expand Up @@ -180,8 +179,7 @@ struct DispatchCopyIf
status = cuda_cub::synchronize(policy);
CUDA_CUB_RET_IF_FAIL(status);
OffsetT num_selected = get_value(policy, d_num_selected_out);

output_end = output + num_selected;
thrust::advance(output, num_selected);
return status;
}
};
Expand All @@ -197,8 +195,7 @@ THRUST_RUNTIME_FUNCTION OutputIt copy_if(
{
using size_type = typename iterator_traits<InputIt>::difference_type;

size_type num_items = static_cast<size_type>(thrust::distance(first, last));
OutputIt output_end{};
size_type num_items = static_cast<size_type>(thrust::distance(first, last));
cudaError_t status = cudaSuccess;
size_t temp_storage_bytes = 0;

Expand All @@ -214,7 +211,7 @@ THRUST_RUNTIME_FUNCTION OutputIt copy_if(
dispatch32_t::dispatch,
dispatch64_t::dispatch,
num_items,
(policy, nullptr, temp_storage_bytes, first, stencil, output, predicate, num_items_fixed, output_end));
(policy, nullptr, temp_storage_bytes, first, stencil, output, predicate, num_items_fixed));
cuda_cub::throw_on_error(status, "copy_if failed on 1st step");

// Allocate temporary storage.
Expand All @@ -227,10 +224,10 @@ THRUST_RUNTIME_FUNCTION OutputIt copy_if(
dispatch32_t::dispatch,
dispatch64_t::dispatch,
num_items,
(policy, temp_storage, temp_storage_bytes, first, stencil, output, predicate, num_items_fixed, output_end));
(policy, temp_storage, temp_storage_bytes, first, stencil, output, predicate, num_items_fixed));
cuda_cub::throw_on_error(status, "copy_if failed on 2nd step");

return output_end;
return output;
}

} // namespace detail
Expand Down

0 comments on commit 463aecd

Please sign in to comment.