Skip to content

Commit

Permalink
Improve GPU performance by changing the offset computation of the `Fu…
Browse files Browse the repository at this point in the history
…llGridCellList` (#50)

* Subtract `min_corner` in `cell_coords` instead of subtracting `min_cell` in `getindex`

* Fix copy and periodicity

* Fix periodicity

* Fix tests
  • Loading branch information
efaulhaber authored Jul 5, 2024
1 parent 35ed828 commit c5a13fe
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 29 deletions.
45 changes: 27 additions & 18 deletions src/cell_lists/full_grid.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ See [`copy_neighborhood_search`](@ref) for more details.
struct FullGridCellList{C, LI, MC} <: AbstractCellList
cells :: C
linear_indices :: LI
min_cell :: MC
min_corner :: MC
end

function supported_update_strategies(::FullGridCellList{<:DynamicVectorOfVectors})
Expand All @@ -44,32 +44,30 @@ supported_update_strategies(::FullGridCellList) = (SemiParallelUpdate, SerialUpd
function FullGridCellList(; min_corner, max_corner, search_radius = 0.0,
periodicity = false, backend = DynamicVectorOfVectors{Int32},
max_points_per_cell = 100)
# Pad domain to avoid 0 in cell indices due to rounding errors.
# We can't just use `eps()`, as one might use lower precision types.
# This padding is safe, and will give us one more layer of cells in the worst case.
min_corner = SVector(Tuple(min_corner .- 1e-3 * search_radius))
max_corner = SVector(Tuple(max_corner .+ 1e-3 * search_radius))

if search_radius < eps()
# Create an empty "template" cell list to be used with `copy_cell_list`
cells = construct_backend(backend, 0, 0)
linear_indices = nothing

# Misuse `min_cell` to store min and max corner for copying
min_cell = (min_corner, max_corner)
# Misuse `min_corner` to store min and max corner for copying
min_corner = (min_corner, max_corner)
else
if periodicity
# Subtract `min_corner` because that's how the grid NHS works with periodicity
max_corner = max_corner .- min_corner
min_corner = min_corner .- min_corner
end

# Note that we don't shift everything so that the first cell starts at `min_corner`.
# The first cell is the cell containing `min_corner`, so we need to add one layer
# in order for `max_corner` to be inside a cell.
n_cells_per_dimension = ceil.(Int, (max_corner .- min_corner) ./ search_radius) .+ 1
linear_indices = LinearIndices(Tuple(n_cells_per_dimension))
min_cell = Tuple(floor_to_int.(min_corner ./ search_radius))

cells = construct_backend(backend, n_cells_per_dimension, max_points_per_cell)
end

return FullGridCellList{typeof(cells), typeof(linear_indices),
typeof(min_cell)}(cells, linear_indices, min_cell)
return FullGridCellList(cells, linear_indices, min_corner)
end

function construct_backend(::Type{Vector{Vector{T}}}, size, max_points_per_cell) where {T}
Expand All @@ -94,6 +92,17 @@ function construct_backend(::Type{DynamicVectorOfVectors{T1, T2, T3, T4}}, size,
return construct_backend(DynamicVectorOfVectors{T1}, size, max_points_per_cell)
end

# Convert `coords` to the 1-based cell coordinates of a `FullGridCellList` for the case
# that no periodic box is used (`periodic_box` is `nothing`).
@inline function cell_coords(coords, periodic_box::Nothing, cell_list::FullGridCellList,
                             cell_size)
    grid_origin = cell_list.min_corner

    # Shift by the grid's `min_corner` so that the min corner of the grid falls into the
    # first cell. The floor division yields zero-based cell coordinates.
    # NOTE(review): the original comment states that `min_corner == periodic_box.min_corner`,
    # so periodic boxes need no special handling here. This method itself only accepts
    # `periodic_box::Nothing` -- confirm against the periodic `cell_coords` method.
    zero_based = floor_to_int.((coords .- grid_origin) ./ cell_size)

    # Add one for 1-based indexing. The min corner will be the (1, 1, 1)-cell.
    return Tuple(zero_based) .+ 1
end

function Base.empty!(cell_list::FullGridCellList)
(; cells) = cell_list

Expand All @@ -107,7 +116,7 @@ end

function Base.empty!(cell_list::FullGridCellList{Nothing})
# This is an empty "template" cell list to be used with `copy_cell_list`
throw(UndefRefError("`search_radius` is not defined for this cell list"))
error("`search_radius` is not defined for this cell list")
end

function push_cell!(cell_list::FullGridCellList, cell, particle)
Expand All @@ -121,7 +130,7 @@ end

function push_cell!(cell_list::FullGridCellList{Nothing}, cell, particle)
# This is an empty "template" cell list to be used with `copy_cell_list`
throw(UndefRefError("`search_radius` is not defined for this cell list"))
error("`search_radius` is not defined for this cell list")
end

@inline function push_cell_atomic!(cell_list::FullGridCellList, cell, particle)
Expand All @@ -146,13 +155,13 @@ end

function each_cell_index(cell_list::FullGridCellList{Nothing})
# This is an empty "template" cell list to be used with `copy_cell_list`
throw(UndefRefError("`search_radius` is not defined for this cell list"))
error("`search_radius` is not defined for this cell list")
end

@inline function cell_index(cell_list::FullGridCellList, cell::Tuple)
(; linear_indices, min_cell) = cell_list
(; linear_indices) = cell_list

return linear_indices[(cell .- min_cell .+ 1)...]
return linear_indices[cell...]
end

# An integer `cell` is returned unchanged (presumably it already is a linear index into
# the cell list -- confirm against callers).
@inline function cell_index(::FullGridCellList, cell::Integer)
    return cell
end
Expand All @@ -171,7 +180,7 @@ end

function copy_cell_list(cell_list::FullGridCellList, search_radius, periodic_box)
# Misuse `min_corner` to store min and max corner for copying
min_corner, max_corner = cell_list.min_cell
min_corner, max_corner = cell_list.min_corner

return FullGridCellList(; min_corner, max_corner, search_radius,
periodicity = !isnothing(periodic_box),
Expand Down
18 changes: 10 additions & 8 deletions src/nhs_grid.jl
Original file line number Diff line number Diff line change
Expand Up @@ -400,28 +400,30 @@ end

# Without a periodic box (`nothing`), no wrapping is applied and the cell index is
# returned unchanged. `n_cells` is unused by this method.
@inline function periodic_cell_index(cell_index, ::Nothing, n_cells)
    return cell_index
end

@inline function periodic_cell_index(cell_index, periodic_box, n_cells)
return rem.(cell_index, n_cells, RoundDown)
@inline function periodic_cell_index(cell_index, ::PeriodicBox, n_cells)
# 1-based modulo
return rem.(cell_index .- 1, n_cells, RoundDown) .+ 1
end

@inline function cell_coords(coords, neighborhood_search)
(; periodic_box, cell_size) = neighborhood_search
(; periodic_box, cell_list, cell_size) = neighborhood_search

return cell_coords(coords, periodic_box, cell_size)
return cell_coords(coords, periodic_box, cell_list, cell_size)
end

@inline function cell_coords(coords, periodic_box::Nothing, cell_size)
@inline function cell_coords(coords, periodic_box::Nothing, cell_list, cell_size)
return Tuple(floor_to_int.(coords ./ cell_size))
end

@inline function cell_coords(coords, periodic_box, cell_size)
@inline function cell_coords(coords, periodic_box::PeriodicBox, cell_list, cell_size)
# Subtract `min_corner` to offset coordinates so that the min corner of the periodic
# box corresponds to the (0, 0) cell of the NHS.
# box corresponds to the (0, 0, 0) cell of the NHS.
# This way, there are no partial cells in the domain if the domain size is an integer
# multiple of the cell size (which is required, see the constructor).
offset_coords = periodic_coords(coords, periodic_box) .- periodic_box.min_corner

return Tuple(floor_to_int.(offset_coords ./ cell_size))
# Add one for 1-based indexing. The min corner will be the (1, 1, 1)-cell.
return Tuple(floor_to_int.(offset_coords ./ cell_size)) .+ 1
end

function copy_neighborhood_search(nhs::GridNeighborhoodSearch, search_radius, n_points;
Expand Down
19 changes: 16 additions & 3 deletions test/nhs_grid.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,25 @@
coords2 = [NaN, 0]
coords3 = [typemax(Int) + 1.0, -typemax(Int) - 1.0]

@test PointNeighbors.cell_coords(coords1, nothing, (1.0, 1.0)) ==
@test PointNeighbors.cell_coords(coords1, nothing, nothing, (1.0, 1.0)) ==
(typemax(Int), typemin(Int))
@test PointNeighbors.cell_coords(coords2, nothing, (1.0, 1.0)) ==
@test PointNeighbors.cell_coords(coords2, nothing, nothing, (1.0, 1.0)) ==
(typemax(Int), 0)
@test PointNeighbors.cell_coords(coords3, nothing, (1.0, 1.0)) ==
@test PointNeighbors.cell_coords(coords3, nothing, nothing, (1.0, 1.0)) ==
(typemax(Int), typemin(Int))

# The full grid cell list adds one to the coordinates to avoid zero-indexing.
# This corner case is not relevant, as `typemax` coordinates will always be out of
# bounds for the finite domain of the full grid cell list.
cell_list = FullGridCellList(min_corner = (0.0, 0.0), max_corner = (1.0, 1.0),
search_radius = 1.0)

@test PointNeighbors.cell_coords(coords1, nothing, cell_list, (1.0, 1.0)) ==
(typemax(Int), typemin(Int)) .+ 1
@test PointNeighbors.cell_coords(coords2, nothing, cell_list, (1.0, 1.0)) ==
(typemax(Int), 0) .+ 1
@test PointNeighbors.cell_coords(coords3, nothing, cell_list, (1.0, 1.0)) ==
(typemax(Int), typemin(Int)) .+ 1
end

@testset "Rectangular Point Cloud 2D" begin
Expand Down

0 comments on commit c5a13fe

Please sign in to comment.