Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add compact syntax for overlapping multiple DSS calls #1573

Merged
merged 2 commits into from
Dec 6, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions src/Fields/Fields.jl
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,51 @@ Spaces.weighted_dss_internal!(field::Field, dss_buffer) =
Spaces.weighted_dss_ghost!(field::Field, dss_buffer) =
Spaces.weighted_dss_ghost!(field_values(field), axes(field), dss_buffer)

"""
Spaces.weighted_dss!(field1 => ghost_buffer1, field2 => ghost_buffer2, ...)

Call [`Spaces.weighted_dss!`](@ref) on multiple fields at once, overlapping
communication as much as possible.
"""
function Spaces.weighted_dss!(
(field1, dss_buffer1)::Pair,
field_buffer_pairs::Pair...,
)
device = ClimaComms.device(axes(field1))
Spaces.weighted_dss_prepare!(
field_values(field1),
axes(field1),
dss_buffer1,
)
for (field, dss_buffer) in field_buffer_pairs
Spaces.weighted_dss_prepare!(
field_values(field),
axes(field),
dss_buffer,
)
end

if device isa ClimaComms.CUDADevice
CUDA.synchronize(; blocking = true)
end

ClimaComms.start(dss_buffer1.graph_context)
for (field, dss_buffer) in field_buffer_pairs
ClimaComms.start(dss_buffer.graph_context)
end

Spaces.weighted_dss_internal!(field1, dss_buffer1)
for (field, dss_buffer) in field_buffer_pairs
Spaces.weighted_dss_internal!(field, dss_buffer)
end

Spaces.weighted_dss_ghost!(field1, dss_buffer1)
for (field, dss_buffer) in field_buffer_pairs
Spaces.weighted_dss_ghost!(field, dss_buffer)
end

return nothing
end

# Add definitions for backward compatibility
Spaces.weighted_dss2!(
Expand Down
83 changes: 43 additions & 40 deletions src/Spaces/dss.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,37 @@ function weighted_dss!(
weighted_dss_ghost!(data, space, dss_buffer)
end

function weighted_dss_prepare!(
data::Union{DataLayouts.IJFH, DataLayouts.VIJFH},
space::Union{
Spaces.SpectralElementSpace2D,
Spaces.ExtrudedFiniteDifferenceSpace,
},
dss_buffer::DSSBuffer,
)
assert_same_eltype(data, dss_buffer)
length(parent(data)) == 0 && return nothing
device = ClimaComms.device(topology(space))
hspace = horizontal_space(space)
dss_transform!(
device,
dss_buffer,
data,
local_geometry_data(space),
local_dss_weights(hspace),
Spaces.perimeter(hspace),
dss_buffer.perimeter_elems,
)
dss_local_ghost!(
device,
dss_buffer.perimeter_data,
Spaces.perimeter(hspace),
topology(hspace),
)
fill_send_buffer!(device, dss_buffer; synchronize = false)
return nothing
end


"""
weighted_dss_start!(
Expand Down Expand Up @@ -117,57 +148,29 @@ representative ghost vertices which store result of "ghost local" DSS are loaded

4). Start DSS communication with neighboring processes
"""
weighted_dss_start!(
data::Union{
DataLayouts.IFH,
DataLayouts.VIFH,
DataLayouts.IJFH,
DataLayouts.VIJFH,
},
space::Union{AbstractSpectralElementSpace, ExtrudedFiniteDifferenceSpace},
dss_buffer::Union{DSSBuffer, Nothing},
) = weighted_dss_start!(data, space, horizontal_space(space), dss_buffer)



function weighted_dss_start!(
data::Union{DataLayouts.IJFH, DataLayouts.VIJFH},
space::Union{
Spaces.SpectralElementSpace2D,
Spaces.ExtrudedFiniteDifferenceSpace,
},
hspace::SpectralElementSpace2D,
dss_buffer::DSSBuffer,
)
assert_same_eltype(data, dss_buffer)
length(parent(data)) == 0 && return nothing
device = ClimaComms.device(topology(hspace))
dss_transform!(
device,
dss_buffer,
data,
local_geometry_data(space),
local_dss_weights(hspace),
Spaces.perimeter(hspace),
dss_buffer.perimeter_elems,
)
dss_local_ghost!(
device,
dss_buffer.perimeter_data,
Spaces.perimeter(hspace),
topology(hspace),
)
fill_send_buffer!(device, dss_buffer)
device = ClimaComms.device(topology(space))
weighted_dss_prepare!(data, space, dss_buffer)
if device isa ClimaComms.CUDADevice
CUDA.synchronize(; blocking = true)
end
ClimaComms.start(dss_buffer.graph_context)
return nothing
end

weighted_dss_start!(
data,
space,
hspace::SpectralElementSpace1D,
dss_buffer::Nothing,
) = nothing
weighted_dss_start!(data, space, dss_buffer::Nothing) = nothing

# TODO: deprecate
weighted_dss_start!(data, space, hspace, dss_buffer) =
weighted_dss_start!(data, space, dss_buffer)


"""
weighted_dss_internal!(
Expand Down Expand Up @@ -299,9 +302,9 @@ function weighted_dss_ghost!(
dss_buffer::DSSBuffer,
)
assert_same_eltype(data, dss_buffer)
ClimaComms.finish(dss_buffer.graph_context)
length(parent(data)) == 0 && return data
device = ClimaComms.device(topology(hspace))
ClimaComms.finish(dss_buffer.graph_context)
load_from_recv_buffer!(device, dss_buffer)
dss_ghost!(
device,
Expand Down
14 changes: 9 additions & 5 deletions src/Topologies/dss.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,14 @@ $(DocStringExtensions.FIELDS)
struct DSSBuffer{S, G, D, A, B, VI}
"ClimaComms graph context for communication"
graph_context::G
"Array for storing perimeter data"
"""
Perimeter `DataLayout` object: typically a `VIFH{TT,Np}`, where `TT` is the
transformed type, and `Np` is the length of the perimeter
"""
perimeter_data::D
"send buffer"
"send buffer `AbstractVector{FT}`"
send_data::A
"recv buffer"
"recv buffer `AbstractVector{FT}`"
recv_data::A
"indexing array for loading send buffer from `perimeter_data`"
send_buf_idx::B
Expand Down Expand Up @@ -730,7 +733,7 @@ function dss_ghost!(
end

"""
fill_send_buffer!(::ClimaComms.AbstractCPUDevice, dss_buffer::DSSBuffer)
fill_send_buffer!(::ClimaComms.AbstractCPUDevice, dss_buffer::DSSBuffer; synchronize=true)

Loads the send buffer from `perimeter_data`. For unique ghost vertices, only data from the
representative vertices which store result of "ghost local" DSS are loaded.
Expand All @@ -739,7 +742,8 @@ Part of [`ClimaCore.Spaces.weighted_dss!`](@ref).
"""
function fill_send_buffer!(
::ClimaComms.AbstractCPUDevice,
dss_buffer::DSSBuffer,
dss_buffer::DSSBuffer;
synchronize = true,
)
(; perimeter_data, send_buf_idx, send_data) = dss_buffer
(Np, _, _, Nv, nelems) = size(perimeter_data)
Expand Down
10 changes: 8 additions & 2 deletions src/Topologies/dss_cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,11 @@ function dss_local_ghost_kernel!(
return nothing
end

function fill_send_buffer!(::ClimaComms.CUDADevice, dss_buffer::DSSBuffer)
function fill_send_buffer!(
::ClimaComms.CUDADevice,
dss_buffer::DSSBuffer;
synchronize = true,
)
(; perimeter_data, send_buf_idx, send_data) = dss_buffer
pperimeter_data = parent(perimeter_data)
(nlevels, nperimeter, nfid, nelems) = size(pperimeter_data)
Expand All @@ -432,7 +436,9 @@ function fill_send_buffer!(::ClimaComms.CUDADevice, dss_buffer::DSSBuffer)
send_buf_idx,
pperimeter_data,
)
CUDA.synchronize(; blocking = true) # CUDA MPI uses a separate stream. This will synchronize across streams
if synchronize
CUDA.synchronize(; blocking = true) # CUDA MPI uses a separate stream. This will synchronize across streams
end
end
return nothing
end
Expand Down
4 changes: 2 additions & 2 deletions test/Spaces/ddss1_cs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,8 @@ end
dss_buffer12 = Spaces.create_dss_buffer(y12)
dss_buffer12_cpu = Spaces.create_dss_buffer(y12_cpu)
# ensure physical velocity is continous across SE boundary for initial state
Spaces.weighted_dss!(y12, dss_buffer12)
Spaces.weighted_dss!(y12_cpu, dss_buffer12_cpu)
Spaces.weighted_dss!(y12 => dss_buffer12)
Spaces.weighted_dss!(y12_cpu => dss_buffer12_cpu)

yinit12 = copy(y12)
yinit12_cpu = copy(y12_cpu)
Expand Down
Loading