Skip to content

Commit

Permalink
Fix kernel abstractions with Reactant GPU
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jan 7, 2025
1 parent 26b9b70 commit 7c10e29
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 8 deletions.
16 changes: 15 additions & 1 deletion ext/ReactantCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,20 @@ struct LLVMFunc{F,tt}
entry::String
end

function Base.getproperty(f::LLVMFunc{F, tt}, sym::Symbol) where {F, tt}
if sym === :fun
f
else
Base.getfield(f, sym)
end
end

# TODO in the future we may want to avoid doing a second cufunction compilation
# for computing the thread/block count (or potentially do it ourselves).
@noinline function CUDA.launch_configuration(f::LLVMFunc{F, tt}; shmem::Union{Integer, Base.Callable}=0, max_threads::Integer=0) where {F, tt}
CUDA.launch_configuration(Base.inferencebarrier(CUDA.cufunction)(f.f, Tuple{tt.parameters[2:end]...}).fun; shmem, max_threads)
end

const GPUCompiler = CUDA.GPUCompiler
const LLVM = GPUCompiler.LLVM

Expand Down Expand Up @@ -456,7 +470,7 @@ Reactant.@reactant_overlay @noinline function CUDA.cufunction(
)
CUDA.GPUCompiler.cached_compilation(cache, source, config, compile, link)
end
return res
return Core.Typeof(res)(f, res.entry)
end

function Reactant.traced_type(
Expand Down
27 changes: 20 additions & 7 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,11 @@ function should_rewrite_ft(@nospecialize(ft))
has_ancestor(mod, Reactant.TracedRandom)
return false
end
if string(mod) == "CUDA"
if ft.name.name == Symbol("#launch_configuration")
return false
end
end
end
end
# Don't rewrite Val
Expand Down Expand Up @@ -153,6 +158,8 @@ function should_rewrite_ft(@nospecialize(ft))
return false
end



# Default assume all functions need to be reactant-ified
return true
end
Expand Down Expand Up @@ -217,7 +224,7 @@ function rewrite_inst(inst, ir, interp, RT, guaranteed_error)
end
if ft == typeof(Core._apply_iterate)
ft = Core.Compiler.widenconst(maybe_argextype(inst.args[3], ir))
if should_rewrite_ft(ft)
if Base.invokelatest(should_rewrite_ft, ft)
if RT === Union{}
rep = Expr(
:call,
Expand All @@ -231,7 +238,7 @@ function rewrite_inst(inst, ir, interp, RT, guaranteed_error)
return true, rep, Any
end
end
elseif should_rewrite_ft(ft)
elseif Base.invokelatest(should_rewrite_ft, ft)
if RT === Union{}
rep = Expr(:call, call_with_reactant, MustThrowError(), inst.args...)
return true, rep, Union{}
Expand All @@ -248,7 +255,7 @@ function rewrite_inst(inst, ir, interp, RT, guaranteed_error)
if ft == typeof(Core.kwcall)
ft = sig.parameters[3]
end
if should_rewrite_ft(ft) && !is_reactant_method(omi)
if Base.invokelatest(should_rewrite_ft, ft) && !is_reactant_method(omi)
method = omi.def::Core.Method

min_world = Ref{UInt}(typemin(UInt))
Expand Down Expand Up @@ -479,9 +486,15 @@ function call_with_reactant_generator(
return stub(world, source, builtin_error)
end

method_error = :(throw(
MethodError($REDUB_ARGUMENTS_NAME[1], $REDUB_ARGUMENTS_NAME[2:end], $world)
))
if guaranteed_error
method_error = :(throw(
MethodError($REDUB_ARGUMENTS_NAME[2], $REDUB_ARGUMENTS_NAME[3:end], $world)
))
else
method_error = :(throw(
MethodError($REDUB_ARGUMENTS_NAME[1], $REDUB_ARGUMENTS_NAME[2:end], $world)
))
end

interp = ReactantInterpreter(; world)

Expand Down Expand Up @@ -675,7 +688,7 @@ function call_with_reactant_generator(
dict, make_oc = if Base.issingletontype(fn)
Base.Ref{Core.OpaqueClosure}(), make_oc_ref
else
Dict{args[1],Core.OpaqueClosure}(), make_oc_dict
Dict{fn,Core.OpaqueClosure}(), make_oc_dict
end

push!(oc_capture_vec, dict)
Expand Down

0 comments on commit 7c10e29

Please sign in to comment.