From 5b025d491e9c5572421e8b79e71807d4a4aae5f3 Mon Sep 17 00:00:00 2001 From: Erik Faulhaber <44124897+efaulhaber@users.noreply.github.com> Date: Mon, 23 Dec 2024 11:26:00 +0100 Subject: [PATCH] Fix localmem kernel --- src/nhs_grid.jl | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/nhs_grid.jl b/src/nhs_grid.jl index dc9570d..1df3b8a 100644 --- a/src/nhs_grid.jl +++ b/src/nhs_grid.jl @@ -414,14 +414,16 @@ end ndrange = max_particles_per_cell * length(nonempty_cells) n_gpus = length(CUDA.devices()) - ndrange_local = [div(ndrange, n_gpus) for _ in 1:n_gpus] - ndrange_local[end] += ndrange % n_gpus + cells_split = Iterators.partition(nonempty_cells, ceil(Int, length(nonempty_cells) / n_gpus)) + @assert length(cells_split) == n_gpus kernel = foreach_neighbor_localmem(backend, (max_particles_per_cell,)) - @sync for i in 1:n_gpus + @sync for (i, nonempty_cells_) in enumerate(cells_split) Threads.@spawn begin CUDA.device!(i - 1) - kernel(f, system_coords, neighbor_coords, neighborhood_search, nonempty_cells, Val(max_particles_per_cell), search_radius; ndrange = ndrange_local[i]) + kernel(f, system_coords, neighbor_coords, neighborhood_search, nonempty_cells_, + Val(max_particles_per_cell), search_radius; + ndrange = length(nonempty_cells_) * max_particles_per_cell) KernelAbstractions.synchronize(backend) end end