Skip to content

Commit

Permalink
Fix adapt dispatch for KernelAdaptor
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Jan 2, 2025
1 parent b5816bc commit 08ff758
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 24 deletions.
30 changes: 30 additions & 0 deletions ext/cuda/adapt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,33 @@ function Adapt.adapt(to::ToCUDA, data::DataLayouts.AbstractData)
Adapt.adapt(CUDA.CuArray, parent(data)),
)
end

Adapt.adapt_structure(
to::CUDA.KernelAdaptor,
grid::Grids.ExtrudedFiniteDifferenceGrid,
) = Grids.DeviceExtrudedFiniteDifferenceGrid(
Adapt.adapt(to, Grids.vertical_topology(grid)),
Adapt.adapt(to, grid.horizontal_grid.quadrature_style),
Adapt.adapt(to, grid.global_geometry),
Adapt.adapt(to, grid.center_local_geometry),
Adapt.adapt(to, grid.face_local_geometry),
)

Adapt.adapt_structure(
to::CUDA.KernelAdaptor,
grid::Grids.FiniteDifferenceGrid,
) = Grids.DeviceFiniteDifferenceGrid(
Adapt.adapt(to, grid.topology),
Adapt.adapt(to, grid.global_geometry),
Adapt.adapt(to, grid.center_local_geometry),
Adapt.adapt(to, grid.face_local_geometry),
)

Adapt.adapt_structure(
to::CUDA.KernelAdaptor,
grid::Grids.SpectralElementGrid2D,
) = Grids.DeviceSpectralElementGrid2D(
Adapt.adapt(to, grid.quadrature_style),
Adapt.adapt(to, grid.global_geometry),
Adapt.adapt(to, grid.local_geometry),
)
9 changes: 0 additions & 9 deletions src/Grids/extruded.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,15 +166,6 @@ local_geometry_type(
::Type{DeviceExtrudedFiniteDifferenceGrid{VT, Q, GG, CLG, FLG}},
) where {VT, Q, GG, CLG, FLG} = eltype(CLG) # calls eltype from DataLayouts

Adapt.adapt_structure(to, grid::ExtrudedFiniteDifferenceGrid) =
DeviceExtrudedFiniteDifferenceGrid(
Adapt.adapt(to, vertical_topology(grid)),
Adapt.adapt(to, grid.horizontal_grid.quadrature_style),
Adapt.adapt(to, grid.global_geometry),
Adapt.adapt(to, grid.center_local_geometry),
Adapt.adapt(to, grid.face_local_geometry),
)

quadrature_style(grid::DeviceExtrudedFiniteDifferenceGrid) =
grid.quadrature_style
vertical_topology(grid::DeviceExtrudedFiniteDifferenceGrid) =
Expand Down
8 changes: 0 additions & 8 deletions src/Grids/finitedifference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -198,14 +198,6 @@ local_geometry_type(
::Type{DeviceFiniteDifferenceGrid{T, GG, CLG, FLG}},
) where {T, GG, CLG, FLG} = eltype(CLG) # calls eltype from DataLayouts

Adapt.adapt_structure(to, grid::FiniteDifferenceGrid) =
DeviceFiniteDifferenceGrid(
Adapt.adapt(to, grid.topology),
Adapt.adapt(to, grid.global_geometry),
Adapt.adapt(to, grid.center_local_geometry),
Adapt.adapt(to, grid.face_local_geometry),
)

topology(grid::DeviceFiniteDifferenceGrid) = grid.topology
vertical_topology(grid::DeviceFiniteDifferenceGrid) = grid.topology

Expand Down
7 changes: 0 additions & 7 deletions src/Grids/spectralelement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -620,13 +620,6 @@ end
ClimaComms.context(grid::DeviceSpectralElementGrid2D) = DeviceSideContext()
ClimaComms.device(grid::DeviceSpectralElementGrid2D) = DeviceSideDevice()

Adapt.adapt_structure(to, grid::SpectralElementGrid2D) =
DeviceSpectralElementGrid2D(
Adapt.adapt(to, grid.quadrature_style),
Adapt.adapt(to, grid.global_geometry),
Adapt.adapt(to, grid.local_geometry),
)

## aliases
const RectilinearSpectralElementGrid2D =
SpectralElementGrid2D{<:Topologies.RectilinearTopology2D}
Expand Down

0 comments on commit 08ff758

Please sign in to comment.