Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce post-operation callback #2115

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ext/ClimaCoreCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import CUDA
using CUDA
using CUDA: threadIdx, blockIdx, blockDim
import StaticArrays: SVector, SMatrix, SArray
import ClimaCore.DataLayouts: call_post_op_callback, post_op_callback
import ClimaCore.DataLayouts: mapreduce_cuda
import ClimaCore.DataLayouts: ToCUDA
import ClimaCore.DataLayouts: slab, column
Expand Down
18 changes: 12 additions & 6 deletions ext/cuda/data_layouts_copyto.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ if VERSION ≥ v"1.11.0-beta"
# special-case fixes for https://github.com/JuliaLang/julia/issues/28126
# (including the GPU-variant related issue resolution efforts:
# JuliaGPU/GPUArrays.jl#454, JuliaGPU/GPUArrays.jl#464).
function Base.copyto!(dest::AbstractData, bc, ::ToCUDA)
function Base.copyto!(dest::AbstractData, bc, to::ToCUDA)
(_, _, Nv, _, Nh) = DataLayouts.universal_size(dest)
us = DataLayouts.UniversalSize(dest)
if Nv > 0 && Nh > 0
Expand All @@ -39,10 +39,11 @@ if VERSION ≥ v"1.11.0-beta"
blocks_s = p.blocks,
)
end
call_post_op_callback() && post_op_callback(dest, (dest, bc, to), (;))
return dest
end
else
function Base.copyto!(dest::AbstractData, bc, ::ToCUDA)
function Base.copyto!(dest::AbstractData, bc, to::ToCUDA)
(_, _, Nv, _, Nh) = DataLayouts.universal_size(dest)
us = DataLayouts.UniversalSize(dest)
if Nv > 0 && Nh > 0
Expand Down Expand Up @@ -74,6 +75,7 @@ else
)
end
end
call_post_op_callback() && post_op_callback(dest, (dest, bc, to), (;))
return dest
end
end
Expand All @@ -85,7 +87,7 @@ end
function Base.copyto!(
dest::AbstractData,
bc::Base.Broadcast.Broadcasted{Style},
::ToCUDA,
to::ToCUDA,
) where {
Style <:
Union{Base.Broadcast.AbstractArrayStyle{0}, Base.Broadcast.Style{Tuple}},
Expand All @@ -95,13 +97,14 @@ function Base.copyto!(
)
@inbounds bc0 = bc[]
fill!(dest, bc0)
call_post_op_callback() && post_op_callback(dest, (dest, bc, to), (;))
end

# For field-vector operations
function DataLayouts.copyto_per_field!(
array::AbstractArray,
bc::Union{AbstractArray, Base.Broadcast.Broadcasted},
::ToCUDA,
to::ToCUDA,
)
bc′ = DataLayouts.to_non_extruded_broadcasted(bc)
# All field variables are treated separately, so
Expand All @@ -119,6 +122,7 @@ function DataLayouts.copyto_per_field!(
threads_s = p.threads,
blocks_s = p.blocks,
)
call_post_op_callback() && post_op_callback(array, (array, bc, to), (;))
return array
end
function copyto_per_field_kernel!(array, bc, N)
Expand All @@ -133,7 +137,7 @@ end
function DataLayouts.copyto_per_field_scalar!(
array::AbstractArray,
bc::Base.Broadcast.Broadcasted{Style},
::ToCUDA,
to::ToCUDA,
) where {
Style <:
Union{Base.Broadcast.AbstractArrayStyle{0}, Base.Broadcast.Style{Tuple}},
Expand All @@ -154,12 +158,13 @@ function DataLayouts.copyto_per_field_scalar!(
threads_s = p.threads,
blocks_s = p.blocks,
)
call_post_op_callback() && post_op_callback(array, (array, bc, to), (;))
return array
end
function DataLayouts.copyto_per_field_scalar!(
array::AbstractArray,
bc::Real,
::ToCUDA,
to::ToCUDA,
)
bc′ = DataLayouts.to_non_extruded_broadcasted(bc)
# All field variables are treated separately, so
Expand All @@ -177,6 +182,7 @@ function DataLayouts.copyto_per_field_scalar!(
threads_s = p.threads,
blocks_s = p.blocks,
)
call_post_op_callback() && post_op_callback(array, (array, bc, to), (;))
return array
end
function copyto_per_field_kernel_0D!(array, bc, N)
Expand Down
3 changes: 2 additions & 1 deletion ext/cuda/data_layouts_fill.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ function knl_fill_linear!(dest, val, us)
return nothing
end

function Base.fill!(dest::AbstractData, bc, ::ToCUDA)
function Base.fill!(dest::AbstractData, bc, to::ToCUDA)
(_, _, Nv, _, Nh) = DataLayouts.universal_size(dest)
us = DataLayouts.UniversalSize(dest)
args = (dest, bc, us)
Expand All @@ -41,5 +41,6 @@ function Base.fill!(dest::AbstractData, bc, ::ToCUDA)
)
end
end
call_post_op_callback() && post_op_callback(dest, (dest, bc, to), (;))
return dest
end
17 changes: 15 additions & 2 deletions ext/cuda/data_layouts_mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,13 @@ function mapreduce_cuda(
)
pdata = parent(data)
S = eltype(data)
return DataLayouts.DataF{S}(Array(Array(f(pdata))[1, :]))
data_out = DataLayouts.DataF{S}(Array(Array(f(pdata))[1, :]))
call_post_op_callback() && post_op_callback(
data_out,
(f, op, data),
(; weighted_jacobian, opargs...),
)
return data_out
end

function mapreduce_cuda(
Expand Down Expand Up @@ -101,7 +107,14 @@ function mapreduce_cuda(
Val(shmemsize),
)
end
return DataLayouts.DataF{S}(Array(Array(reduce_cuda)[1, :]))
data_out = DataLayouts.DataF{S}(Array(Array(reduce_cuda)[1, :]))

call_post_op_callback() && post_op_callback(
data_out,
(f, op, data),
(; weighted_jacobian, opargs...),
)
return data_out
end

function mapreduce_cuda_kernel!(
Expand Down
21 changes: 16 additions & 5 deletions ext/cuda/fields.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,46 +13,57 @@ end

function Base.sum(
field::Union{Field, Base.Broadcast.Broadcasted{<:FieldStyle}},
::ClimaComms.CUDADevice,
dev::ClimaComms.CUDADevice,
)
context = ClimaComms.context(axes(field))
localsum = mapreduce_cuda(identity, +, field, weighting = true)
ClimaComms.allreduce!(context, parent(localsum), +)
call_post_op_callback() && post_op_callback(localsum[], (field, dev), (;))
return localsum[]
end

function Base.sum(fn, field::Field, ::ClimaComms.CUDADevice)
function Base.sum(fn, field::Field, dev::ClimaComms.CUDADevice)
context = ClimaComms.context(axes(field))
localsum = mapreduce_cuda(fn, +, field, weighting = true)
ClimaComms.allreduce!(context, parent(localsum), +)
call_post_op_callback() &&
post_op_callback(localsum[], (fn, field, dev), (;))
return localsum[]
end

function Base.maximum(fn, field::Field, ::ClimaComms.CUDADevice)
function Base.maximum(fn, field::Field, dev::ClimaComms.CUDADevice)
context = ClimaComms.context(axes(field))
localmax = mapreduce_cuda(fn, max, field)
ClimaComms.allreduce!(context, parent(localmax), max)
call_post_op_callback() &&
post_op_callback(localmax[], (fn, field, dev), (;))
return localmax[]
end

function Base.maximum(field::Field, ::ClimaComms.CUDADevice)
function Base.maximum(field::Field, dev::ClimaComms.CUDADevice)
context = ClimaComms.context(axes(field))
localmax = mapreduce_cuda(identity, max, field)
ClimaComms.allreduce!(context, parent(localmax), max)
call_post_op_callback() &&
post_op_callback(localmax[], (fn, field, dev), (;))
return localmax[]
end

function Base.minimum(fn, field::Field, ::ClimaComms.CUDADevice)
function Base.minimum(fn, field::Field, dev::ClimaComms.CUDADevice)
context = ClimaComms.context(axes(field))
localmin = mapreduce_cuda(fn, min, field)
ClimaComms.allreduce!(context, parent(localmin), min)
call_post_op_callback() &&
post_op_callback(localmin[], (fn, field, dev), (;))
return localmin[]
end

function Base.minimum(field::Field, ::ClimaComms.CUDADevice)
context = ClimaComms.context(axes(field))
localmin = mapreduce_cuda(identity, min, field)
ClimaComms.allreduce!(context, parent(localmin), min)
call_post_op_callback() &&
post_op_callback(localmin[], (fn, field, dev), (;))
return localmin[]
end

Expand Down
4 changes: 3 additions & 1 deletion ext/cuda/matrix_fields_multiple_field_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import ClimaCore.Utilities.UnrolledFunctions: unrolled_map
is_CuArray_type(::Type{T}) where {T <: CUDA.CuArray} = true

NVTX.@annotate function multiple_field_solve!(
::ClimaComms.CUDADevice,
dev::ClimaComms.CUDADevice,
cache,
x,
A,
Expand Down Expand Up @@ -48,6 +48,8 @@ NVTX.@annotate function multiple_field_solve!(
blocks_s = p.blocks,
always_inline = true,
)
call_post_op_callback() &&
post_op_callback(x, (dev, cache, x, A, b, x1), (;))
end

Base.@propagate_inbounds column_A(A::UniformScaling, i, j, h) = A
Expand Down
2 changes: 2 additions & 0 deletions ext/cuda/matrix_fields_single_field_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ function single_field_solve!(device::ClimaComms.CUDADevice, cache, x, A, b)
threads_s = p.threads,
blocks_s = p.blocks,
)
call_post_op_callback() &&
post_op_callback(x, (device, cache, x, A, b), (;))
end

function single_field_solve_kernel!(device, cache, x, A, b, us)
Expand Down
1 change: 1 addition & 0 deletions ext/cuda/operators_finite_difference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ function Base.copyto!(
threads_s = p.threads,
blocks_s = p.blocks,
)
call_post_op_callback() && post_op_callback(out, (out, bc), (;))
return out
end
import ClimaCore.DataLayouts: get_N, get_Nv, get_Nij, get_Nij, get_Nh
Expand Down
7 changes: 6 additions & 1 deletion ext/cuda/operators_integral.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import ClimaComms
using CUDA: @cuda

function column_reduce_device!(
::ClimaComms.CUDADevice,
dev::ClimaComms.CUDADevice,
f::F,
transform::T,
output,
Expand Down Expand Up @@ -40,6 +40,11 @@ function column_reduce_device!(
threads_s = p.threads,
blocks_s = p.blocks,
)
call_post_op_callback() && post_op_callback(
output,
(dev, f, transform, output, input, init, space),
(;),
)
end

function column_accumulate_device!(
Expand Down
1 change: 1 addition & 0 deletions ext/cuda/operators_spectral_element.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ function Base.copyto!(
threads_s = p.threads,
blocks_s = p.blocks,
)
call_post_op_callback() && post_op_callback(out, (out, sbc), (;))
return out
end

Expand Down
3 changes: 2 additions & 1 deletion ext/cuda/operators_thomas_algorithm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import ClimaCore.Operators:
column_thomas_solve!, thomas_algorithm_kernel!, thomas_algorithm!
import CUDA
using CUDA: @cuda
function column_thomas_solve!(::ClimaComms.CUDADevice, A, b)
function column_thomas_solve!(dev::ClimaComms.CUDADevice, A, b)
us = UniversalSize(Fields.field_values(A))
args = (A, b, us)
Ni, Nj, _, _, Nh = size(Fields.field_values(A))
Expand All @@ -18,6 +18,7 @@ function column_thomas_solve!(::ClimaComms.CUDADevice, A, b)
threads_s = p.threads,
blocks_s = p.blocks,
)
call_post_op_callback() && post_op_callback(b, (dev, A, b), (;))
end

function thomas_algorithm_kernel!(
Expand Down
16 changes: 15 additions & 1 deletion ext/cuda/remapping_distributed.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ function _set_interpolated_values_device!(
interpolation_matrix,
vert_interpolation_weights::AbstractArray,
vert_bounding_indices::AbstractArray,
::ClimaComms.CUDADevice,
dev::ClimaComms.CUDADevice,
)
# FIXME: Avoid allocation of tuple
field_values = tuple(map(f -> Fields.field_values(f), fields)...)
Expand All @@ -33,6 +33,20 @@ function _set_interpolated_values_device!(
threads_s = (nthreads),
blocks_s = (nblocks),
)
call_post_op_callback() && post_op_callback(
out,
(
out,
fields,
scratch_field_values,
local_horiz_indices,
interpolation_matrix,
vert_interpolation_weights,
vert_bounding_indices,
dev,
),
(;),
)
end

# GPU, 3D case
Expand Down
5 changes: 5 additions & 0 deletions ext/cuda/remapping_interpolate_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ function interpolate_slab!(
threads_s = (nthreads),
blocks_s = (nblocks),
)
call_post_op_callback() && post_op_callback(
output_array,
(output_array, field, slab_indices, weights, device),
(;),
)

output_array .= Array(output_cuarray)
end
Expand Down
Loading
Loading