Skip to content

Commit

Permalink
Introduce post-operation callback
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Jan 8, 2025
1 parent 81a8fa1 commit 1531bf4
Show file tree
Hide file tree
Showing 25 changed files with 192 additions and 36 deletions.
1 change: 1 addition & 0 deletions ext/ClimaCoreCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import CUDA
using CUDA
using CUDA: threadIdx, blockIdx, blockDim
import StaticArrays: SVector, SMatrix, SArray
import ClimaCore.DataLayouts: call_post_op_callback, post_op_callback
import ClimaCore.DataLayouts: mapreduce_cuda
import ClimaCore.DataLayouts: ToCUDA
import ClimaCore.DataLayouts: slab, column
Expand Down
18 changes: 12 additions & 6 deletions ext/cuda/data_layouts_copyto.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ if VERSION ≥ v"1.11.0-beta"
# special-case fixes for https://github.com/JuliaLang/julia/issues/28126
# (including the GPU-variant related issue resolution efforts:
# JuliaGPU/GPUArrays.jl#454, JuliaGPU/GPUArrays.jl#464).
function Base.copyto!(dest::AbstractData, bc, ::ToCUDA)
function Base.copyto!(dest::AbstractData, bc, to::ToCUDA)
(_, _, Nv, _, Nh) = DataLayouts.universal_size(dest)
us = DataLayouts.UniversalSize(dest)
if Nv > 0 && Nh > 0
Expand All @@ -39,10 +39,11 @@ if VERSION ≥ v"1.11.0-beta"
blocks_s = p.blocks,
)
end
call_post_op_callback() && post_op_callback(dest, (dest, bc, to), (;))
return dest
end
else
function Base.copyto!(dest::AbstractData, bc, ::ToCUDA)
function Base.copyto!(dest::AbstractData, bc, to::ToCUDA)
(_, _, Nv, _, Nh) = DataLayouts.universal_size(dest)
us = DataLayouts.UniversalSize(dest)
if Nv > 0 && Nh > 0
Expand Down Expand Up @@ -74,6 +75,7 @@ else
)
end
end
call_post_op_callback() && post_op_callback(dest, (dest, bc, to), (;))
return dest
end
end
Expand All @@ -85,7 +87,7 @@ end
function Base.copyto!(
dest::AbstractData,
bc::Base.Broadcast.Broadcasted{Style},
::ToCUDA,
to::ToCUDA,
) where {
Style <:
Union{Base.Broadcast.AbstractArrayStyle{0}, Base.Broadcast.Style{Tuple}},
Expand All @@ -95,13 +97,14 @@ function Base.copyto!(
)
@inbounds bc0 = bc[]
fill!(dest, bc0)
call_post_op_callback() && post_op_callback(dest, (dest, bc, to), (;))
end

# For field-vector operations
function DataLayouts.copyto_per_field!(
array::AbstractArray,
bc::Union{AbstractArray, Base.Broadcast.Broadcasted},
::ToCUDA,
to::ToCUDA,
)
bc′ = DataLayouts.to_non_extruded_broadcasted(bc)
# All field variables are treated separately, so
Expand All @@ -119,6 +122,7 @@ function DataLayouts.copyto_per_field!(
threads_s = p.threads,
blocks_s = p.blocks,
)
call_post_op_callback() && post_op_callback(array, (array, bc, to), (;))
return array
end
function copyto_per_field_kernel!(array, bc, N)
Expand All @@ -133,7 +137,7 @@ end
function DataLayouts.copyto_per_field_scalar!(
array::AbstractArray,
bc::Base.Broadcast.Broadcasted{Style},
::ToCUDA,
to::ToCUDA,
) where {
Style <:
Union{Base.Broadcast.AbstractArrayStyle{0}, Base.Broadcast.Style{Tuple}},
Expand All @@ -154,12 +158,13 @@ function DataLayouts.copyto_per_field_scalar!(
threads_s = p.threads,
blocks_s = p.blocks,
)
call_post_op_callback() && post_op_callback(array, (array, bc, to), (;))
return array
end
function DataLayouts.copyto_per_field_scalar!(
array::AbstractArray,
bc::Real,
::ToCUDA,
to::ToCUDA,
)
bc′ = DataLayouts.to_non_extruded_broadcasted(bc)
# All field variables are treated separately, so
Expand All @@ -177,6 +182,7 @@ function DataLayouts.copyto_per_field_scalar!(
threads_s = p.threads,
blocks_s = p.blocks,
)
call_post_op_callback() && post_op_callback(array, (array, bc, to), (;))
return array
end
function copyto_per_field_kernel_0D!(array, bc, N)
Expand Down
3 changes: 2 additions & 1 deletion ext/cuda/data_layouts_fill.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ function knl_fill_linear!(dest, val, us)
return nothing
end

function Base.fill!(dest::AbstractData, bc, ::ToCUDA)
function Base.fill!(dest::AbstractData, bc, to::ToCUDA)
(_, _, Nv, _, Nh) = DataLayouts.universal_size(dest)
us = DataLayouts.UniversalSize(dest)
args = (dest, bc, us)
Expand All @@ -41,5 +41,6 @@ function Base.fill!(dest::AbstractData, bc, ::ToCUDA)
)
end
end
call_post_op_callback() && post_op_callback(dest, (dest, bc, to), (;))
return dest
end
17 changes: 15 additions & 2 deletions ext/cuda/data_layouts_mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,13 @@ function mapreduce_cuda(
)
pdata = parent(data)
S = eltype(data)
return DataLayouts.DataF{S}(Array(Array(f(pdata))[1, :]))
data_out = DataLayouts.DataF{S}(Array(Array(f(pdata))[1, :]))
call_post_op_callback() && post_op_callback(
data_out,
(f, op, data),
(; weighted_jacobian, opargs...),
)
return data_out
end

function mapreduce_cuda(
Expand Down Expand Up @@ -101,7 +107,14 @@ function mapreduce_cuda(
Val(shmemsize),
)
end
return DataLayouts.DataF{S}(Array(Array(reduce_cuda)[1, :]))
data_out = DataLayouts.DataF{S}(Array(Array(reduce_cuda)[1, :]))

call_post_op_callback() && post_op_callback(
data_out,
(f, op, data),
(; weighted_jacobian, opargs...),
)
return data_out
end

function mapreduce_cuda_kernel!(
Expand Down
21 changes: 16 additions & 5 deletions ext/cuda/fields.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,46 +13,57 @@ end

function Base.sum(
field::Union{Field, Base.Broadcast.Broadcasted{<:FieldStyle}},
::ClimaComms.CUDADevice,
dev::ClimaComms.CUDADevice,
)
context = ClimaComms.context(axes(field))
localsum = mapreduce_cuda(identity, +, field, weighting = true)
ClimaComms.allreduce!(context, parent(localsum), +)
call_post_op_callback() && post_op_callback(localsum[], (field, dev), (;))
return localsum[]
end

function Base.sum(fn, field::Field, ::ClimaComms.CUDADevice)
function Base.sum(fn, field::Field, dev::ClimaComms.CUDADevice)
context = ClimaComms.context(axes(field))
localsum = mapreduce_cuda(fn, +, field, weighting = true)
ClimaComms.allreduce!(context, parent(localsum), +)
call_post_op_callback() &&
post_op_callback(localsum[], (fn, field, dev), (;))
return localsum[]
end

function Base.maximum(fn, field::Field, ::ClimaComms.CUDADevice)
function Base.maximum(fn, field::Field, dev::ClimaComms.CUDADevice)
context = ClimaComms.context(axes(field))
localmax = mapreduce_cuda(fn, max, field)
ClimaComms.allreduce!(context, parent(localmax), max)
call_post_op_callback() &&
post_op_callback(localmax[], (fn, field, dev), (;))
return localmax[]
end

function Base.maximum(field::Field, ::ClimaComms.CUDADevice)
function Base.maximum(field::Field, dev::ClimaComms.CUDADevice)
context = ClimaComms.context(axes(field))
localmax = mapreduce_cuda(identity, max, field)
ClimaComms.allreduce!(context, parent(localmax), max)
call_post_op_callback() &&
post_op_callback(localmax[], (fn, field, dev), (;))
return localmax[]
end

function Base.minimum(fn, field::Field, ::ClimaComms.CUDADevice)
function Base.minimum(fn, field::Field, dev::ClimaComms.CUDADevice)
context = ClimaComms.context(axes(field))
localmin = mapreduce_cuda(fn, min, field)
ClimaComms.allreduce!(context, parent(localmin), min)
call_post_op_callback() &&
post_op_callback(localmin[], (fn, field, dev), (;))
return localmin[]
end

function Base.minimum(field::Field, ::ClimaComms.CUDADevice)
context = ClimaComms.context(axes(field))
localmin = mapreduce_cuda(identity, min, field)
ClimaComms.allreduce!(context, parent(localmin), min)
call_post_op_callback() &&
post_op_callback(localmin[], (fn, field, dev), (;))
return localmin[]
end

Expand Down
4 changes: 3 additions & 1 deletion ext/cuda/matrix_fields_multiple_field_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import ClimaCore.Utilities.UnrolledFunctions: unrolled_map
is_CuArray_type(::Type{T}) where {T <: CUDA.CuArray} = true

NVTX.@annotate function multiple_field_solve!(
::ClimaComms.CUDADevice,
dev::ClimaComms.CUDADevice,
cache,
x,
A,
Expand Down Expand Up @@ -48,6 +48,8 @@ NVTX.@annotate function multiple_field_solve!(
blocks_s = p.blocks,
always_inline = true,
)
call_post_op_callback() &&
post_op_callback(x, (dev, cache, x, A, b, x1), (;))
end

Base.@propagate_inbounds column_A(A::UniformScaling, i, j, h) = A
Expand Down
2 changes: 2 additions & 0 deletions ext/cuda/matrix_fields_single_field_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ function single_field_solve!(device::ClimaComms.CUDADevice, cache, x, A, b)
threads_s = p.threads,
blocks_s = p.blocks,
)
call_post_op_callback() &&
post_op_callback(x, (device, cache, x, A, b), (;))
end

function single_field_solve_kernel!(device, cache, x, A, b, us)
Expand Down
1 change: 1 addition & 0 deletions ext/cuda/operators_finite_difference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ function Base.copyto!(
threads_s = p.threads,
blocks_s = p.blocks,
)
call_post_op_callback() && post_op_callback(out, (out, bc), (;))
return out
end
import ClimaCore.DataLayouts: get_N, get_Nv, get_Nij, get_Nij, get_Nh
Expand Down
7 changes: 6 additions & 1 deletion ext/cuda/operators_integral.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import ClimaComms
using CUDA: @cuda

function column_reduce_device!(
::ClimaComms.CUDADevice,
dev::ClimaComms.CUDADevice,
f::F,
transform::T,
output,
Expand Down Expand Up @@ -40,6 +40,11 @@ function column_reduce_device!(
threads_s = p.threads,
blocks_s = p.blocks,
)
call_post_op_callback() && post_op_callback(
output,
(dev, f, transform, output, input, init, space),
(;),
)
end

function column_accumulate_device!(
Expand Down
1 change: 1 addition & 0 deletions ext/cuda/operators_spectral_element.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ function Base.copyto!(
threads_s = p.threads,
blocks_s = p.blocks,
)
call_post_op_callback() && post_op_callback(out, (out, sbc), (;))
return out
end

Expand Down
3 changes: 2 additions & 1 deletion ext/cuda/operators_thomas_algorithm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import ClimaCore.Operators:
column_thomas_solve!, thomas_algorithm_kernel!, thomas_algorithm!
import CUDA
using CUDA: @cuda
function column_thomas_solve!(::ClimaComms.CUDADevice, A, b)
function column_thomas_solve!(dev::ClimaComms.CUDADevice, A, b)
us = UniversalSize(Fields.field_values(A))
args = (A, b, us)
Ni, Nj, _, _, Nh = size(Fields.field_values(A))
Expand All @@ -18,6 +18,7 @@ function column_thomas_solve!(::ClimaComms.CUDADevice, A, b)
threads_s = p.threads,
blocks_s = p.blocks,
)
call_post_op_callback() && post_op_callback(b, (dev, A, b), (;))
end

function thomas_algorithm_kernel!(
Expand Down
16 changes: 15 additions & 1 deletion ext/cuda/remapping_distributed.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ function _set_interpolated_values_device!(
interpolation_matrix,
vert_interpolation_weights::AbstractArray,
vert_bounding_indices::AbstractArray,
::ClimaComms.CUDADevice,
dev::ClimaComms.CUDADevice,
)
# FIXME: Avoid allocation of tuple
field_values = tuple(map(f -> Fields.field_values(f), fields)...)
Expand All @@ -33,6 +33,20 @@ function _set_interpolated_values_device!(
threads_s = (nthreads),
blocks_s = (nblocks),
)
call_post_op_callback() && post_op_callback(
out,
(
out,
fields,
scratch_field_values,
local_horiz_indices,
interpolation_matrix,
vert_interpolation_weights,
vert_bounding_indices,
dev,
),
(;),
)
end

# GPU, 3D case
Expand Down
5 changes: 5 additions & 0 deletions ext/cuda/remapping_interpolate_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ function interpolate_slab!(
threads_s = (nthreads),
blocks_s = (nblocks),
)
call_post_op_callback() && post_op_callback(
output_array,
(output_array, field, slab_indices, weights, device),
(;),
)

output_array .= Array(output_cuarray)
end
Expand Down
Loading

0 comments on commit 1531bf4

Please sign in to comment.