Skip to content

Commit

Permalink
Abstract away use of ClimaComms mpicomm
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Jan 6, 2025
1 parent 81a8fa1 commit 951f4f4
Show file tree
Hide file tree
Showing 6 changed files with 10 additions and 4 deletions.
1 change: 1 addition & 0 deletions src/ClimaCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,6 @@ include("CommonGrids/CommonGrids.jl")
include("CommonSpaces/CommonSpaces.jl")

include("deprecated.jl")
include("compat.jl")

end # module
1 change: 1 addition & 0 deletions src/InputOutput/InputOutput.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import ..Geometry,
..Hypsography
import ..VERSION
import ..Utilities: PlusHalf, half
import ..climacomms_mpicomm

include("writers.jl")
include("readers.jl")
Expand Down
2 changes: 1 addition & 1 deletion src/InputOutput/readers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ function HDF5Reader(
if context isa ClimaComms.SingletonCommsContext
file = h5open(filename, "r")
else
file = h5open(filename, "r", context.mpicomm)
file = h5open(filename, "r", climacomms_mpicomm(context))
end
if !haskey(attrs(file), "ClimaCore version")
error("Not a ClimaCore HDF5 file")
Expand Down
2 changes: 1 addition & 1 deletion src/InputOutput/writers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ function HDF5Writer(
if context isa ClimaComms.SingletonCommsContext
file = h5open(filename, mode)
else
file = h5open(filename, mode, context.mpicomm)
file = h5open(filename, mode, climacomms_mpicomm(context))
end
# Add an attribute to the file if it doesn't already exist
if haskey(attributes(file), "ClimaCore version")
Expand Down
3 changes: 3 additions & 0 deletions src/compat.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
import ClimaComms

climacomms_mpicomm(ctx::ClimaComms.MPICommsContext) = ctx.mpicomm
5 changes: 3 additions & 2 deletions test/Operators/spectralelement/benchmark_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import ClimaCore.Geometry as Geometry
import ClimaCore.Quadratures as Quadratures
import ClimaCore.Fields as Fields
import ClimaCore as CC
import ClimaCore: climacomms_mpicomm
import ClimaComms
ClimaComms.@import_required_backends

Expand Down Expand Up @@ -160,9 +161,9 @@ function setup_kernel_args(ARGS::Vector{String} = ARGS)
device isa ClimaComms.CUDADevice
# assign GPUs based on local rank
local_comm = ClimaComms.MPI.Comm_split_type(
context.mpicomm,
climacomms_mpicomm(context),
ClimaComms.MPI.COMM_TYPE_SHARED,
ClimaComms.MPI.Comm_rank(context.mpicomm),
ClimaComms.MPI.Comm_rank(climacomms_mpicomm(context)),
)
CUDA.device!(
ClimaComms.MPI.Comm_rank(local_comm) % length(CUDA.devices()),
Expand Down

0 comments on commit 951f4f4

Please sign in to comment.