Skip to content

Commit

Permalink
Cuv2 (#423)
Browse files Browse the repository at this point in the history
* Kernel-supporting jll

* fix rulescc

* adapt to hedron dep

* init target

* fixup

* additional fixups

* fixup

* fix

* registry utils

* callname

* reg

* fix

* fix bld

* cleanup

* no pip

* fix

* force rules python to older version before bug

* fixup jll

* with proto

* fix

* fix

* Update WORKSPACE

* more deps for apple

* bump

* fix

* workspace bump

* workspace

* Update Compiler.jl

* Update ReactantCUDAExt.jl

* Update ReactantCUDAExt.jl

* Update ReactantCUDAExt.jl

* Update ReactantCUDAExt.jl

* Update ReactantCUDAExt.jl

* Update Project.toml

* Update ReactantCUDAExt.jl

* Update Project.toml

* Update Project.toml

* fix

* Update ReactantCUDAExt.jl

* Update ReactantCUDAExt.jl

* Update ReactantCUDAExt.jl

* Update cuda.jl

* Update cuda.jl

* Update cuda.jl

* Cuda kernel v2

* Update Project.toml

* Update API.cpp

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

---------

Co-authored-by: William Moses <wsmoses@cyclops.juliacomputing.io>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Dec 24, 2024
1 parent 057e6b8 commit 925544f
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 7 deletions.
3 changes: 2 additions & 1 deletion ext/ReactantCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,8 @@ function __init__()
end
Reactant.Compiler.cuLaunch[] = Base.reinterpret(UInt, ptr1)
Reactant.Compiler.cuModule[] = Base.reinterpret(UInt, ptr2)
return Reactant.Compiler.cuFunc[] = Base.reinterpret(UInt, ptr3)
Reactant.Compiler.cuFunc[] = Base.reinterpret(UInt, ptr3)
return nothing
end

end # module ReactantCUDAExt
2 changes: 1 addition & 1 deletion src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true)
if isdefined(Reactant_jll, :ptxas_path)
toolkit = Reactant_jll.ptxas_path[1:(end - length("/bin/ptxas"))]
end
kern = "lower-kernel{toolkitPath=$toolkit cuLaunchKernelPtr=$(cuLaunch[]) cuModuleLoadDataPtr=$(cuModule[]) cuModuleGetFunctionPtr=$(cuFunc[])}"
kern = "lower-kernel{run_init=true toolkitPath=$toolkit cuLaunchKernelPtr=$(cuLaunch[]) cuModuleLoadDataPtr=$(cuModule[]) cuModuleGetFunctionPtr=$(cuFunc[])}"
if optimize === :all
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ","))
run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false)
Expand Down
17 changes: 12 additions & 5 deletions test/cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,16 @@ using Reactant
using Test
using CUDA

using Reactant_jll
@show Reactant_jll.libReactantExtra_path

function square_kernel!(x)
i = threadIdx().x
x[i] *= x[i]
#i = threadIdx().x
#x[i] *= x[i]
#@cuprintf("overwrote value of %f was thrown during kernel execution on thread (%d, %d, %d) in block (%d, %d, %d).\n",
# 0.0, threadIdx().x, threadIdx().y, threadIdx().z, blockIdx().x, blockIdx().y, blockIdx().z)
#x[i], threadIdx().x, threadIdx().y, threadIdx().z, blockIdx().x, blockIdx().y, blockIdx().z)

# sync_threads()
return nothing
end
Expand All @@ -18,9 +25,9 @@ end
@testset "Square Kernel" begin
oA = collect(1:1:64)
A = Reactant.to_rarray(oA)
@show @code_hlo optimize = false square!(A)
@show @code_hlo optimize = :before_kernel square!(A)
@show @code_hlo square!(A)
# @show @code_hlo optimize = false square!(A)
# @show @code_hlo optimize = :before_kernel square!(A)
# @show @code_hlo square!(A)
func! = @compile square!(A)
func!(A)
@show A
Expand Down

0 comments on commit 925544f

Please sign in to comment.