Skip to content

Commit

Permalink
Fix adapt calls, rm device-side structs
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Jan 6, 2025
1 parent 951f4f4 commit 041ca3e
Show file tree
Hide file tree
Showing 9 changed files with 81 additions and 67 deletions.
41 changes: 41 additions & 0 deletions ext/cuda/adapt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import Adapt
import CUDA
import ClimaCore: Grids

Adapt.adapt_structure(
to::CUDA.KernelAdaptor,
device::ClimaComms.CPUSingleThreaded,
) = ClimaComms.CUDADevice()

Adapt.adapt_structure(
to::CUDA.KernelAdaptor,
device::ClimaComms.CPUMultiThreaded,
) = ClimaComms.CUDADevice()

Adapt.adapt_structure(
to::CUDA.KernelAdaptor,
context::ClimaComms.SingletonCommsContext,
) = ClimaComms.context(Adapt.adapt(to, context.device))

Adapt.adapt_structure(
to::CUDA.KernelAdaptor,
grid::Grids.SpectralElementGrid2D,
) = Grids.DeviceSpectralElementGrid2D(
Adapt.adapt(to, grid.topology),
Adapt.adapt(to, grid.quadrature_style),
Adapt.adapt(to, grid.global_geometry),
Adapt.adapt(to, grid.local_geometry),
Adapt.adapt(to, grid.local_dss_weights),
Adapt.adapt(to, grid.internal_surface_geometry),
Adapt.adapt(to, grid.boundary_surface_geometries),
Adapt.adapt(to, grid.enable_bubble),
)

Adapt.adapt_structure(
to::CUDA.KernelAdaptor,
topology::Topologies.IntervalTopology,
) = IntervalTopology(
Adapt.adapt_structure(to, topology.context),
Adapt.adapt_structure(to, topology.mesh),
Adapt.adapt_structure(to, topology.boundaries),
)
1 change: 0 additions & 1 deletion src/Grids/Grids.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import ..DataLayouts,
..Domains, ..Meshes, ..Topologies, ..Geometry, ..Quadratures
import ..Utilities: PlusHalf, half, Cache
import ..slab, ..column, ..level
import ..DeviceSideDevice, ..DeviceSideContext

using StaticArrays

Expand Down
39 changes: 26 additions & 13 deletions src/Grids/spectralelement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -587,25 +587,38 @@ quadrature_style(grid::AbstractSpectralElementGrid) = grid.quadrature_style
local_dss_weights(grid::SpectralElementGrid1D) = grid.dss_weights
local_dss_weights(grid::SpectralElementGrid2D) = grid.local_dss_weights

## GPU compatibility
struct DeviceSpectralElementGrid2D{Q, GG, LG} <: AbstractSpectralElementGrid
## Same as SpectralElementGrid2D, but immutable (for GPU compatibility)
struct DeviceSpectralElementGrid2D{
T,
Q,
GG <: Geometry.AbstractGlobalGeometry,
LG,
D,
IS,
BS,
} <: AbstractSpectralElementGrid
topology::T
quadrature_style::Q
global_geometry::GG
local_geometry::LG
local_dss_weights::D
internal_surface_geometry::IS
boundary_surface_geometries::BS
enable_bubble::Bool
end

ClimaComms.context(grid::DeviceSpectralElementGrid2D) = DeviceSideContext()
ClimaComms.device(grid::DeviceSpectralElementGrid2D) = DeviceSideDevice()

Adapt.adapt_structure(to, grid::SpectralElementGrid2D) =
DeviceSpectralElementGrid2D(
Adapt.adapt(to, grid.quadrature_style),
Adapt.adapt(to, grid.global_geometry),
Adapt.adapt(to, grid.local_geometry),
)
local_dss_weights(grid::DeviceSpectralElementGrid2D) = grid.local_dss_weights
local_geometry_type(
::Type{DeviceSpectralElementGrid2D{T, Q, GG, LG, D, IS, BS}},
) where {T, Q, GG, LG, D, IS, BS} = eltype(LG) # calls eltype from DataLayouts

# for both CPU/GPU SpectralElementGrid2D
const SpectralElementGrids2D{T, Q, GG, LG, D, IS, BS} = Union{
SpectralElementGrid2D{T, Q, GG, LG, D, IS, BS},
DeviceSpectralElementGrid2D{T, Q, GG, LG, D, IS, BS},
}
## aliases
const RectilinearSpectralElementGrid2D =
SpectralElementGrid2D{<:Topologies.RectilinearTopology2D}
SpectralElementGrids2D{<:Topologies.RectilinearTopology2D}
const CubedSphereSpectralElementGrid2D =
SpectralElementGrid2D{<:Topologies.CubedSphereTopology2D}
SpectralElementGrids2D{<:Topologies.CubedSphereTopology2D}
2 changes: 0 additions & 2 deletions src/Spaces/Spaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ import ..DataLayouts,
import ..Domains: z_max, z_min
import ..Meshes: n_elements_per_panel_direction

import ..DeviceSideDevice, ..DeviceSideContext

import ..Grids:
Staggering,
CellFace,
Expand Down
6 changes: 4 additions & 2 deletions src/Spaces/pointspace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,10 @@ function PointSpace(
end


Adapt.adapt_structure(to, space::PointSpace) =
PointSpace(DeviceSideContext(), Adapt.adapt(to, local_geometry_data(space)))
Adapt.adapt_structure(to, space::PointSpace) = PointSpace(
ClimaComms.CUDADevice(),
Adapt.adapt(to, local_geometry_data(space)),
)

function PointSpace(
context::ClimaComms.AbstractCommsContext,
Expand Down
2 changes: 0 additions & 2 deletions src/Topologies/Topologies.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ import ..DataLayouts
import ..DataLayouts: slab_index
import ..slab, ..column, ..level

import ..DeviceSideDevice, ..DeviceSideContext

"""
AbstractTopology
Expand Down
10 changes: 0 additions & 10 deletions src/Topologies/interval.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,6 @@ struct IntervalTopology{
boundaries::B
end

## gpu
struct DeviceIntervalTopology{B} <: AbstractIntervalTopology
boundaries::B
end
Adapt.adapt_structure(to, topology::IntervalTopology) =
DeviceIntervalTopology(topology.boundaries)

ClimaComms.context(topology::DeviceIntervalTopology) = DeviceSideContext()
ClimaComms.device(topology::DeviceIntervalTopology) = DeviceSideDevice()

ClimaComms.device(topology::IntervalTopology) = topology.context.device
ClimaComms.array_type(topology::IntervalTopology) =
ClimaComms.array_type(topology.context.device)
Expand Down
28 changes: 0 additions & 28 deletions src/devices.jl
Original file line number Diff line number Diff line change
@@ -1,29 +1 @@
import ClimaComms

# Sometimes, it is convenient to know whether something is running on a device or on its
# host. We define new AbstractDevices and AbstractCommsContext to identify the case of a
# function running on the device with data on the device.

"""
DeviceSideDevice()
This device represents data defined on the device side of an accelerator.
The most common example is data defined on a GPU. DeviceSideDevice() is used for
operations within the accelerator.
"""
struct DeviceSideDevice <: ClimaComms.AbstractDevice end


"""
DeviceSideContext()
Context associated to data defined on the device side of an accelerator.
The most common example is data defined on a GPU. DeviceSideContext() is used for
operations within the accelerator.
"""
struct DeviceSideContext <: ClimaComms.AbstractCommsContext end

ClimaComms.context(::DeviceSideDevice) = DeviceSideContext()
ClimaComms.device(::DeviceSideContext) = DeviceSideDevice()
19 changes: 10 additions & 9 deletions test/Spaces/unit_spaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@ import ClimaCore:
Quadratures,
Fields,
DataLayouts,
Geometry,
DeviceSideContext,
DeviceSideDevice
Geometry

import ClimaCore.DataLayouts: IJFH, VF, slab_index

Expand Down Expand Up @@ -199,12 +197,14 @@ end

if on_gpu
adapted_space = adapt(c_space)(c_space)
@test ClimaComms.context(adapted_space) == DeviceSideContext()
@test ClimaComms.device(adapted_space) == DeviceSideDevice()
@test ClimaComms.context(adapted_space) ==
ClimaComms.context(ClimaComms.CUDADevice())
@test ClimaComms.device(adapted_space) == ClimaComms.CUDADevice()

adapted_hspace = adapt(hspace)(hspace)
@test ClimaComms.context(adapted_hspace) == DeviceSideContext()
@test ClimaComms.device(adapted_hspace) == DeviceSideDevice()
@test ClimaComms.context(adapted_hspace) ==
ClimaComms.context(ClimaComms.CUDADevice())
@test ClimaComms.device(adapted_hspace) == ClimaComms.CUDADevice()
end

end
Expand Down Expand Up @@ -246,8 +246,9 @@ end

if on_gpu
adapted_space = adapt(space)(space)
@test ClimaComms.context(adapted_space) == DeviceSideContext()
@test ClimaComms.device(adapted_space) == DeviceSideDevice()
@test ClimaComms.context(adapted_space) ==
ClimaComms.context(ClimaComms.CUDADevice())
@test ClimaComms.device(adapted_space) == ClimaComms.CUDADevice()
end

for i in 1:4, j in 1:4
Expand Down

0 comments on commit 041ca3e

Please sign in to comment.