Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jan 8, 2025
1 parent 020fcd3 commit 832a20c
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 20 deletions.
34 changes: 17 additions & 17 deletions ext/ReactantCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ function compile(job)
# TODO: on 1.9, this actually creates a context. cache those.
entry = GPUCompiler.JuliaContext() do ctx
mod, meta = GPUCompiler.compile(
# :llvm, job; optimize=false, cleanup=false, validate=false, libraries=true
:llvm, job; optimize=false, cleanup=false, validate=false, libraries=false
)

Expand Down Expand Up @@ -451,7 +452,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
gpu_function_type = MLIR.IR.Type(Reactant.TracedUtils.get_attribute_by_name(gpufunc, "function_type"))

trueidx = 1
allocs = Union{MLIR.IR.Value, Nothing}[]
allocs = Union{Tuple{MLIR.IR.Value, MLIR.IR.Type}, Nothing}[]

llvmptr = MLIR.IR.Type(MLIR.API.mlirLLVMPointerTypeGet(ctx, 0))
i8 = MLIR.IR.Type(UInt8)
Expand All @@ -468,21 +469,18 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
trueidx += 1
c1 = MLIR.IR.result(MLIR.Dialects.llvm.mlir_constant(; res=MLIR.IR.Type(Int64), value=MLIR.IR.Attribute(1)), 1)
alloc = MLIR.IR.result(MLIR.Dialects.llvm.alloca(c1; elem_type=MLIR.IR.Attribute(argty), res=llvmptr), 1)
push!(allocs, alloc)
push!(allocs, (alloc, argty))

sz = sizeof(a)
array_ty = MLIR.IR.Type(MLIR.API.mlirLLVMArrayTypeGet(MLIR.IR.Type(Int8), sz))
cdata = MLIR.IR.result(MLIR.Dialects.llvm.mlir_constant(; res=array_ty, value=MLIR.IR.DenseElementsAttribute(to_bytes(a))), 1)
MLIR.Dialects.llvm.store(cdata, alloc)
argres = MLIR.IR.result(MLIR.Dialects.llvm.load(alloc; res=argty), 1)
push!(wrapargs, argres)
end

end

argidx = 1
for arg in values(seen)
@show arg
if !(arg isa TracedRArray || arg isa TracedRNumber)
continue
end
Expand All @@ -497,24 +495,17 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
push!(mlir_args, arg)

# Get the allocation corresponding to which arg we're doing
alloc = allocs[p[2]]::MLIR.IR.Value
alloc = allocs[p[2]][1]

# we need to now compute the offset in bytes of the path
julia_arg = allargs[p[2]]
@show p
@show julia_arg

offset = get_field_offset(typeof(julia_arg), p[3:end])
@show offset
MLIR.IR.block!(wrapbody) do
c_offset = MLIR.IR.result(MLIR.Dialects.llvm.mlir_constant(; res=MLIR.IR.Type(Int64), value=MLIR.IR.Attribute(1)), 1)
ptr = MLIR.IR.result(MLIR.Dialects.llvm.getelementptr(alloc, [c_offset], res=llvmptr, elem_type=i8, rawConstantIndices=MLIR.IR.Attribute(Int32[typemin(Int32)])), 1)
@show ptr
ptr = MLIR.IR.result(MLIR.Dialects.llvm.getelementptr(alloc, MLIR.IR.Value[], res=llvmptr, elem_type=i8, rawConstantIndices=MLIR.IR.Attribute([Int32(offset)])), 1)
MLIR.Dialects.llvm.store(MLIR.IR.argument(wrapbody, argidx), ptr)
end

argidx += 1

push!(
aliases,
MLIR.IR.Attribute(
Expand All @@ -528,10 +519,20 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
),
),
)

argidx += 1
end
end

MLIR.IR.block!(wrapbody) do
for arg in allocs
if arg === nothing
continue
end
alloc, argty = arg
argres = MLIR.IR.result(MLIR.Dialects.llvm.load(alloc; res=argty), 1)
push!(wrapargs, argres)
end
MLIR.Dialects.llvm.call(wrapargs, MLIR.IR.Value[]; callee=MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)), op_bundle_sizes=MLIR.IR.Attribute(Int32[]))
MLIR.Dialects.llvm.return_(nothing)
end
Expand All @@ -555,10 +556,9 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
fn=MLIR.IR.FlatSymbolRefAttribute(sym_name),
output_operand_aliases=MLIR.IR.Attribute(output_operand_aliases)
)

argidx = 1
for arg in values(seen)
@show arg
if !(arg isa TracedRArray || arg isa TracedRNumber)
continue
end
Expand Down
36 changes: 33 additions & 3 deletions test/integration/cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,6 @@ function smul!(x)
end

@static if !Sys.isapple()

# Broken pending jll update
@static if false
@testset "Constant Op Kernel" begin
oA = collect(1:1:64)
A = Reactant.to_rarray(oA)
Expand All @@ -87,4 +84,37 @@ end
end
end


function tuplef!(tup)
tup[1][] += 2
return nothing
end

function tuplef2!(tup)
tup[2][] *= tup[1]
return nothing
end

tuplef(a) = @cuda threads=1 tuplef!((a,))
tuplef2(a) = @cuda threads=1 tuplef2!((5, a))

@static if !Sys.isapple()
@testset "Structured Kernel Arguments" begin
A = ConcreteRArray(fill(1))
if CUDA.functional()
@jit tuplef(A)
@test all(Array(A) .≈ 3)
else
@code_hlo optimize = :before_kernel tuplef(A)
end

A = ConcreteRArray(fill(1))
if CUDA.functional()
@jit tuplef2(A)
@test all(Array(A) .≈ 5)
else
@code_hlo optimize = :before_kernel tuplef2(A)
end

end
end

0 comments on commit 832a20c

Please sign in to comment.