Skip to content

Commit

Permalink
Introduce post-operation callback
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Jan 3, 2025
1 parent 92f3b2b commit b10fbfc
Show file tree
Hide file tree
Showing 25 changed files with 86 additions and 11 deletions.
1 change: 1 addition & 0 deletions ext/ClimaCoreCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import CUDA
using CUDA
using CUDA: threadIdx, blockIdx, blockDim
import StaticArrays: SVector, SMatrix, SArray
import ClimaCore.DataLayouts: call_post_op_callback, post_op_callback
import ClimaCore.DataLayouts: mapreduce_cuda
import ClimaCore.DataLayouts: ToCUDA
import ClimaCore.DataLayouts: slab, column
Expand Down
6 changes: 6 additions & 0 deletions ext/cuda/data_layouts_copyto.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ if VERSION ≥ v"1.11.0-beta"
blocks_s = p.blocks,
)
end
call_post_op_callback() && post_op_callback(dest)
return dest
end
else
Expand Down Expand Up @@ -74,6 +75,7 @@ else
)
end
end
call_post_op_callback() && post_op_callback(dest)
return dest
end
end
Expand All @@ -95,6 +97,7 @@ function Base.copyto!(
)
@inbounds bc0 = bc[]
fill!(dest, bc0)
call_post_op_callback() && post_op_callback(dest)
end

# For field-vector operations
Expand All @@ -119,6 +122,7 @@ function DataLayouts.copyto_per_field!(
threads_s = p.threads,
blocks_s = p.blocks,
)
call_post_op_callback() && post_op_callback(array)
return array
end
function copyto_per_field_kernel!(array, bc, N)
Expand Down Expand Up @@ -154,6 +158,7 @@ function DataLayouts.copyto_per_field_scalar!(
threads_s = p.threads,
blocks_s = p.blocks,
)
call_post_op_callback() && post_op_callback(array)
return array
end
function DataLayouts.copyto_per_field_scalar!(
Expand All @@ -177,6 +182,7 @@ function DataLayouts.copyto_per_field_scalar!(
threads_s = p.threads,
blocks_s = p.blocks,
)
call_post_op_callback() && post_op_callback(array)
return array
end
function copyto_per_field_kernel_0D!(array, bc, N)
Expand Down
1 change: 1 addition & 0 deletions ext/cuda/data_layouts_fill.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,5 +41,6 @@ function Base.fill!(dest::AbstractData, bc, ::ToCUDA)
)
end
end
call_post_op_callback() && post_op_callback(dest)
return dest
end
8 changes: 6 additions & 2 deletions ext/cuda/data_layouts_mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ function mapreduce_cuda(
)
pdata = parent(data)
S = eltype(data)
return DataLayouts.DataF{S}(Array(Array(f(pdata))[1, :]))
data_out = DataLayouts.DataF{S}(Array(Array(f(pdata))[1, :]))
call_post_op_callback() && post_op_callback(data_out)
return data_out
end

function mapreduce_cuda(
Expand Down Expand Up @@ -101,7 +103,9 @@ function mapreduce_cuda(
Val(shmemsize),
)
end
return DataLayouts.DataF{S}(Array(Array(reduce_cuda)[1, :]))
data_out = DataLayouts.DataF{S}(Array(Array(reduce_cuda)[1, :]))
call_post_op_callback() && post_op_callback(data_out)
return data_out
end

function mapreduce_cuda_kernel!(
Expand Down
6 changes: 6 additions & 0 deletions ext/cuda/fields.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,41 +18,47 @@ function Base.sum(
context = ClimaComms.context(axes(field))
localsum = mapreduce_cuda(identity, +, field, weighting = true)
ClimaComms.allreduce!(context, parent(localsum), +)
call_post_op_callback() && post_op_callback(localsum[])
return localsum[]
end

function Base.sum(fn, field::Field, ::ClimaComms.CUDADevice)
context = ClimaComms.context(axes(field))
localsum = mapreduce_cuda(fn, +, field, weighting = true)
ClimaComms.allreduce!(context, parent(localsum), +)
call_post_op_callback() && post_op_callback(localsum[])
return localsum[]
end

function Base.maximum(fn, field::Field, ::ClimaComms.CUDADevice)
context = ClimaComms.context(axes(field))
localmax = mapreduce_cuda(fn, max, field)
ClimaComms.allreduce!(context, parent(localmax), max)
call_post_op_callback() && post_op_callback(localmax[])
return localmax[]
end

function Base.maximum(field::Field, ::ClimaComms.CUDADevice)
context = ClimaComms.context(axes(field))
localmax = mapreduce_cuda(identity, max, field)
ClimaComms.allreduce!(context, parent(localmax), max)
call_post_op_callback() && post_op_callback(localmax[])
return localmax[]
end

function Base.minimum(fn, field::Field, ::ClimaComms.CUDADevice)
context = ClimaComms.context(axes(field))
localmin = mapreduce_cuda(fn, min, field)
ClimaComms.allreduce!(context, parent(localmin), min)
call_post_op_callback() && post_op_callback(localmin[])
return localmin[]
end

function Base.minimum(field::Field, ::ClimaComms.CUDADevice)
context = ClimaComms.context(axes(field))
localmin = mapreduce_cuda(identity, min, field)
ClimaComms.allreduce!(context, parent(localmin), min)
call_post_op_callback() && post_op_callback(localmin[])
return localmin[]
end

Expand Down
1 change: 1 addition & 0 deletions ext/cuda/matrix_fields_multiple_field_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ NVTX.@annotate function multiple_field_solve!(
blocks_s = p.blocks,
always_inline = true,
)
call_post_op_callback() && post_op_callback(x)
end

Base.@propagate_inbounds column_A(A::UniformScaling, i, j, h) = A
Expand Down
1 change: 1 addition & 0 deletions ext/cuda/matrix_fields_single_field_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ function single_field_solve!(device::ClimaComms.CUDADevice, cache, x, A, b)
threads_s = p.threads,
blocks_s = p.blocks,
)
call_post_op_callback() && post_op_callback(x)
end

function single_field_solve_kernel!(device, cache, x, A, b, us)
Expand Down
1 change: 1 addition & 0 deletions ext/cuda/operators_finite_difference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ function Base.copyto!(
threads_s = p.threads,
blocks_s = p.blocks,
)
call_post_op_callback() && post_op_callback(out)
return out
end
import ClimaCore.DataLayouts: get_N, get_Nv, get_Nij, get_Nij, get_Nh
Expand Down
1 change: 1 addition & 0 deletions ext/cuda/operators_integral.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ function column_reduce_device!(
threads_s = p.threads,
blocks_s = p.blocks,
)
call_post_op_callback() && post_op_callback(output)
end

function column_accumulate_device!(
Expand Down
1 change: 1 addition & 0 deletions ext/cuda/operators_spectral_element.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ function Base.copyto!(
threads_s = p.threads,
blocks_s = p.blocks,
)
call_post_op_callback() && post_op_callback(out)
return out
end

Expand Down
1 change: 1 addition & 0 deletions ext/cuda/operators_thomas_algorithm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ function column_thomas_solve!(::ClimaComms.CUDADevice, A, b)
threads_s = p.threads,
blocks_s = p.blocks,
)
call_post_op_callback() && post_op_callback(b)
end

function thomas_algorithm_kernel!(
Expand Down
1 change: 1 addition & 0 deletions ext/cuda/remapping_distributed.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ function _set_interpolated_values_device!(
threads_s = (nthreads),
blocks_s = (nblocks),
)
call_post_op_callback() && post_op_callback(out)
end

# GPU, 3D case
Expand Down
1 change: 1 addition & 0 deletions ext/cuda/remapping_interpolate_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ function interpolate_slab!(
threads_s = (nthreads),
blocks_s = (nblocks),
)
call_post_op_callback() && post_op_callback(output_array)

output_array .= Array(output_cuarray)
end
Expand Down
9 changes: 9 additions & 0 deletions ext/cuda/topologies_dss.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ function Topologies.dss_load_perimeter_data!(
threads_s = p.threads,
blocks_s = p.blocks,
)
call_post_op_callback() && post_op_callback(perimeter_data)
return nothing
end

Expand Down Expand Up @@ -73,6 +74,7 @@ function Topologies.dss_unload_perimeter_data!(
threads_s = p.threads,
blocks_s = p.blocks,
)
call_post_op_callback() && post_op_callback(data)
return nothing
end

Expand Down Expand Up @@ -123,6 +125,7 @@ function Topologies.dss_local!(
threads_s = p.threads,
blocks_s = p.blocks,
)
call_post_op_callback() && post_op_callback(perimeter_data)
end
return nothing
end
Expand Down Expand Up @@ -213,6 +216,7 @@ function Topologies.dss_transform!(
threads_s = p.threads,
blocks_s = p.blocks,
)
call_post_op_callback() && post_op_callback(perimeter_data)
end
return nothing
end
Expand Down Expand Up @@ -276,6 +280,7 @@ function Topologies.dss_untransform!(
threads_s = p.threads,
blocks_s = p.blocks,
)
call_post_op_callback() && post_op_callback(data)
end
return nothing
end
Expand Down Expand Up @@ -333,6 +338,7 @@ function Topologies.dss_local_ghost!(
threads_s = p.threads,
blocks_s = p.blocks,
)
call_post_op_callback() && post_op_callback(perimeter_data)
end
return nothing
end
Expand Down Expand Up @@ -396,6 +402,7 @@ function Topologies.fill_send_buffer!(
if synchronize
CUDA.synchronize(; blocking = true) # CUDA MPI uses a separate stream. This will synchronize across streams
end
call_post_op_callback() && post_op_callback(send_data)
end
return nothing
end
Expand Down Expand Up @@ -440,6 +447,7 @@ function Topologies.load_from_recv_buffer!(
threads_s = p.threads,
blocks_s = p.blocks,
)
call_post_op_callback() && post_op_callback(perimeter_data)
end
return nothing
end
Expand Down Expand Up @@ -499,6 +507,7 @@ function Topologies.dss_ghost!(
threads_s = p.threads,
blocks_s = p.blocks,
)
call_post_op_callback() && post_op_callback(perimeter_data)
end
return nothing
end
Expand Down
20 changes: 20 additions & 0 deletions src/DataLayouts/DataLayouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2229,4 +2229,24 @@ include("mapreduce.jl")

include("struct_linear_indexing.jl")

"""
post_op_callback(result)
A callback that is called, if `ClimaCore.DataLayouts.call_post_op_callback() =
true`, on the result of every data operation.
There is purposely no implementation-- this is a debugging tool, and users may
want to check different things.
"""
function post_op_callback end

"""
call_post_op_callback()
Returns a Bool. Meant to be overloaded so that
`ClimaCore.DataLayouts.post_op_callback(data::AbstractData)` is called after
every data operation.
"""
call_post_op_callback() = false

end # module
7 changes: 6 additions & 1 deletion src/DataLayouts/copyto.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ if VERSION ≥ v"1.11.0-beta"
dest::AbstractData{S},
bc::Union{AbstractData, Base.Broadcast.Broadcasted},
) where {S}
return Base.copyto!(dest, bc, device_dispatch(parent(dest)))
Base.copyto!(dest, bc, device_dispatch(parent(dest)))
call_post_op_callback() && post_op_callback(dest)
dest
end
else
function Base.copyto!(
Expand All @@ -33,13 +35,15 @@ else
else
Base.copyto!(dest, bc, device_dispatch(parent(dest)))
end
call_post_op_callback() && post_op_callback(dest)
return dest
end
end

# Specialize on non-Broadcasted objects
function Base.copyto!(dest::D, src::D) where {D <: AbstractData}
copyto!(parent(dest), parent(src))
call_post_op_callback() && post_op_callback(dest)
return dest
end

Expand All @@ -58,6 +62,7 @@ function Base.copyto!(
)
@inbounds bc0 = bc[]
fill!(dest, bc0)
call_post_op_callback() && post_op_callback(dest)
end

#####
Expand Down
2 changes: 2 additions & 0 deletions src/DataLayouts/fill.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ function Base.fill!(dest::AbstractData, val)
else
Base.fill!(dest, val, dev)
end
call_post_op_callback() && post_op_callback(dest)
dest
end

function Base.fill!(data::Union{IJFH, IJHF}, val, ::ToCPU)
Expand Down
1 change: 1 addition & 0 deletions src/Fields/Fields.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module Fields
import ClimaComms
import MultiBroadcastFusion as MBF
import ..slab, ..slab_args, ..column, ..column_args, ..level
import ..DataLayouts: call_post_op_callback, post_op_callback
import ..DataLayouts:
DataLayouts,
AbstractData,
Expand Down
20 changes: 12 additions & 8 deletions src/Fields/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,21 @@ context.
See [`sum`](@ref) for the integral over the full domain.
"""
local_sum(
function local_sum(
field::Union{Field, Base.Broadcast.Broadcasted{<:FieldStyle}},
::ClimaComms.AbstractCPUDevice,
) = Base.reduce(
RecursiveApply.radd,
Base.Broadcast.broadcasted(
RecursiveApply.rmul,
Spaces.weighted_jacobian(axes(field)),
todata(field),
),
)
result = Base.reduce(
RecursiveApply.radd,
Base.Broadcast.broadcasted(
RecursiveApply.rmul,
Spaces.weighted_jacobian(axes(field)),
todata(field),
),
)
call_post_op_callback() && post_op_callback(result)
result
end
local_sum(field::Union{Field, Base.Broadcast.Broadcasted{<:FieldStyle}}) =
local_sum(field, ClimaComms.device(axes(field)))
"""
Expand Down
1 change: 1 addition & 0 deletions src/Operators/Operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import Base.Broadcast: Broadcasted

import ..slab, ..slab_args, ..column, ..column_args
import ClimaComms
import ..DataLayouts: call_post_op_callback, post_op_callback
import ..DataLayouts: DataLayouts, Data2D, DataSlab2D
import ..DataLayouts: vindex
import ..Geometry: Geometry, Covariant12Vector, Contravariant12Vector,
Expand Down
2 changes: 2 additions & 0 deletions src/Operators/finitedifference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3779,6 +3779,7 @@ function _serial_copyto!(field_out::Field, bc, Ni::Int, Nj::Int, Nh::Int)
@inbounds for h in 1:Nh, j in 1:Nj, i in 1:Ni
apply_stencil!(space, field_out, bcs, (i, j, h), bounds)
end
call_post_op_callback() && post_op_callback(field_out)
return field_out
end

Expand All @@ -3793,6 +3794,7 @@ function _threaded_copyto!(field_out::Field, bc, Ni::Int, Nj::Int, Nh::Int)
end
end
end
call_post_op_callback() && post_op_callback(field_out)
return field_out
end

Expand Down
Loading

0 comments on commit b10fbfc

Please sign in to comment.