linearize kernel args #497

Open · wants to merge 9 commits into main
Conversation

@mofeing (Collaborator) commented Jan 8, 2025

Review thread on ext/ReactantCUDAExt.jl (outdated, resolved)
@jumerckx (Collaborator) commented Jan 8, 2025

Potentially useful: while messing with the kernel code, changing the second loop to:

    for a in values(seen)
        a isa Reactant.TracedType || continue
        push!(rarrays, a)
        arg = a.mlir_data
        arg = Reactant.TracedUtils.transpose_val(arg)
        push!(restys, MLIR.IR.type(arg))
        push!(mlir_args, arg)
        push!(
            aliases,
            MLIR.IR.Attribute(
                MLIR.API.stablehloOutputOperandAliasGet(
                    MLIR.IR.context(),
                    length(wrapper_tys) == 1 ? 0 : 1,
                    length(wrapper_tys) == 1 ? C_NULL : Ref{Int64}(argidx - 1),
                    argidx - 1,
                    0,
                    C_NULL,
                ),
            ),
        )
        push!(wrapargs, MLIR.IR.argument(wrapbody, argidx))
        argidx += 1
        trueidx += 1
    end

produces the following MLIR IR:
#tbaa_root = #llvm.tbaa_root<id = "custom_tbaa">
#tbaa_type_desc = #llvm.tbaa_type_desc<id = "custom_tbaa_addrspace(1)", members = {<#tbaa_root, 0>}>
#tbaa_tag = #llvm.tbaa_tag<base_type = #tbaa_type_desc, access_type = #tbaa_type_desc, offset = 0>
"builtin.module"() ({
  "llvm.func"() <{CConv = #llvm.cconv<ccc>, function_type = !llvm.func<void (array<1 x array<1 x ptr<1>>>)>, linkage = #llvm.linkage<external>, sym_name = "_Z2f_5TupleI13CuTracedArrayI5Int64Ll0ELl1E2__EE", sym_visibility = "private", unnamed_addr = 1 : i64, visibility_ = 0 : i64}> ({
  ^bb0(%arg3: !llvm.array<1 x array<1 x ptr<1>>>):
    %8 = "llvm.mlir.constant"() <{value = 2 : i64}> : () -> i64
    %9 = "llvm.extractvalue"(%arg3) <{position = array<i64: 0>}> : (!llvm.array<1 x array<1 x ptr<1>>>) -> !llvm.array<1 x ptr<1>>
    %10 = "llvm.extractvalue"(%9) <{position = array<i64: 0>}> : (!llvm.array<1 x ptr<1>>) -> !llvm.ptr<1>
    %11 = "llvm.bitcast"(%10) : (!llvm.ptr<1>) -> !llvm.ptr<1>
    "llvm.store"(%8, %11) <{alignment = 1 : i64, ordering = 0 : i64, tbaa = [#tbaa_tag]}> : (i64, !llvm.ptr<1>) -> ()
    "llvm.return"() : () -> ()
  }) : () -> ()
  "llvm.func"() <{CConv = #llvm.cconv<ptx_kernelcc>, function_type = !llvm.func<void (array<1 x ptr<1>>, array<1 x ptr<1>>)>, linkage = #llvm.linkage<external>, sym_name = "##call__Z2f_5TupleI13CuTracedArrayI5Int64Ll0ELl1E2__EE#236", sym_visibility = "private", visibility_ = 0 : i64}> ({
  ^bb0(%arg1: !llvm.array<1 x ptr<1>>, %arg2: !llvm.array<1 x ptr<1>>):
    "llvm.call"(%arg1) <{CConv = #llvm.cconv<ccc>, TailCallKind = #llvm.tailcallkind<none>, callee = @_Z2f_5TupleI13CuTracedArrayI5Int64Ll0ELl1E2__EE, fastmathFlags = #llvm.fastmath<none>, op_bundle_sizes = array<i32>, operandSegmentSizes = array<i32: 1, 0>}> : (!llvm.array<1 x ptr<1>>) -> ()
    "llvm.return"() : () -> ()
  }) : () -> ()
  "func.func"() <{function_type = (tensor<i64>) -> tensor<i64>, sym_name = "main"}> ({
  ^bb0(%arg0: tensor<i64>):
    %0 = "stablehlo.constant"() <{value = dense<1> : tensor<i64>}> : () -> tensor<i64>
    %1 = "stablehlo.constant"() <{value = dense<1> : tensor<i64>}> : () -> tensor<i64>
    %2 = "stablehlo.constant"() <{value = dense<1> : tensor<i64>}> : () -> tensor<i64>
    %3 = "stablehlo.constant"() <{value = dense<1> : tensor<i64>}> : () -> tensor<i64>
    %4 = "stablehlo.constant"() <{value = dense<1> : tensor<i64>}> : () -> tensor<i64>
    %5 = "stablehlo.constant"() <{value = dense<1> : tensor<i64>}> : () -> tensor<i64>
    %6 = "stablehlo.constant"() <{value = dense<0> : tensor<i64>}> : () -> tensor<i64>
    %7 = "enzymexla.kernel_call"(%0, %1, %2, %3, %4, %5, %6, %arg0) <{backend_config = "", fn = @"##call__Z2f_5TupleI13CuTracedArrayI5Int64Ll0ELl1E2__EE#236", output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 0, operand_tuple_indices = []>]}> : (tensor<i64>, tensor<i64>, tensor<i64>, tensor<i64>, tensor<i64>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<i64>
    "func.return"(%arg0) : (tensor<i64>) -> ()
  }) : () -> ()
}) : () -> ()

which errors with

error: 'llvm.call' op operand type mismatch for operand 0: '!llvm.array<1 x ptr<1>>' != '!llvm.array<1 x array<1 x ptr<1>>>'

😅

@wsmoses wsmoses changed the base branch from ka to main January 8, 2025 20:24
mofeing and others added 2 commits January 8, 2025 15:25
Co-authored-by: jumerckx <31353884+jumerckx@users.noreply.github.com>
@jumerckx (Collaborator) commented Jan 8, 2025

Perhaps a step in the right direction for offset computation 🤞.
But my time's up for today ;)

diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl
index 0801f813..7f8fa858 100644
--- a/ext/ReactantCUDAExt.jl
+++ b/ext/ReactantCUDAExt.jl
@@ -368,6 +368,27 @@ function Reactant.make_tracer(seen, @nospecialize(prev::CuTracedArray), @nospeci
     return Reactant.make_tracer(seen, x, path, mode; kwargs...)
 end
 
+function get_field_offset(T::Type, path)
+    offset = 0
+    current_type = T
+
+    for field in path
+        # Get the field index
+        field_idx = findfirst(==(field), fieldnames(current_type))
+        if field_idx === nothing
+            error("Field $field not found in type $current_type")
+        end
+
+        # Add the offset of this field
+        offset += fieldoffset(current_type, field_idx)
+
+        # Update current_type to the field's type for next iteration
+        current_type = fieldtype(current_type, field_idx)
+    end
+
+    return offset
+end
+
 Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
     args...;
     convert=Val(false),
@@ -386,7 +407,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
     rarrays = TracedRArray[]
 
     fname = func.entry
-    
+
     wrapper_tys = MLIR.IR.Type[]
     ctx = MLIR.IR.context()
     cullvm_ty = MLIR.IR.Type(MLIR.API.mlirLLVMArrayTypeGet(MLIR.API.mlirLLVMPointerTypeGet(ctx, 1), 1))
@@ -395,9 +416,14 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
     seen = Reactant.OrderedIdDict()
     prev = Any[func.f, args...]
     kernelargsym = gensym("kernelarg")
-    make_tracer(seen, prev, (kernelargsym,), Reactant.TracedSetPath)
-    wrapper_tys = fill(cullvm_ty, length(seen))
-    
+    Reactant.make_tracer(seen, prev, (kernelargsym,), Reactant.TracedSetPath)
+    linear_args = Reactant.TracedType[]
+    for v in values(seen)
+        v isa Reactant.TracedType || continue
+        push!(linear_args, v)
+    end
+    wrapper_tys = fill(cullvm_ty, length(linear_args))
+
     sym_name = String(gensym("call_$fname"))
     mod = MLIR.IR.mmodule()
     CConv=MLIR.IR.Attribute(MLIR.API.mlirLLVMCConvAttrGet(ctx, MLIR.API.MlirLLVMCConvPTX_Kernel))
@@ -437,7 +463,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
             c1 = MLIR.IR.result(MLIR.Dialects.llvm.mlir_constant(; res=MLIR.IR.Type(Int64), value=MLIR.IR.Attribute(1)), 1)
             alloc = MLIR.IR.result(MLIR.Dialects.llvm.alloca(c1; elem_type=MLIR.IR.Attribute(argty), res=MLIR.IR.Type(MLIR.API.mlirLLVMPointerTypeGet(ctx, 0))), 1)
             push!(allocs, alloc)
-           
+
             sz = sizeof(a)
             array_ty = MLIR.IR.Type(MLIR.API.mlirLLVMArrayTypeGet(MLIR.IR.Type(Int8), sz))
             cdata = MLIR.IR.result(MLIR.Dialects.llvm.mlir_constant(; res=array_ty, value=MLIR.IR.DenseElementsAttribute(to_bytes(a))), 1)
@@ -449,54 +475,55 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
     end
 
     argidx = 1
-    for arg in values(seen)
+    for arg in values(linear_args)
         for p in Reactant.TracedUtils.get_paths(arg)
             if p[1] !== kernelargsym
                 continue
             end
             # Get the allocation corresponding to which arg we're doing
-            alloc = allocs[p[2]]
+            alloc = allocs[p[2]-1] # some off-by-one confusion?
 
             # we need to now compute the offset in bytes of the path
-            offset = 0
-            ptr = MLIR.gep!(alloc
+            offset = get_field_offset(typeof(args[p[2]-1]), p[3:end])
+            @warn offset
+            # ptr = MLIR.gep!(alloc
 
-            store ptr = arg of wrapped index
+            # store ptr = arg of wrapped index
 
-            argidx += 1
+            # argidx += 1
         end
     end
 
-        if a isa CuTracedArray
-            a =
-                Base.unsafe_pointer_to_objref(Base.reinterpret(Ptr{Cvoid}, a.ptr))::TracedRArray
-        end
-        if a isa TracedRArray || a isa TracedRNumber
-            push!(rarrays, a)
-            arg = a.mlir_data
-            arg = Reactant.TracedUtils.transpose_val(arg)
-            push!(restys, MLIR.IR.type(arg))
-            push!(mlir_args, arg)
-            push!(
-                aliases,
-                MLIR.IR.Attribute(
-                    MLIR.API.stablehloOutputOperandAliasGet(
-                        MLIR.IR.context(),
-                        length(wrapper_tys) == 1 ? 0 : 1,
-                        length(wrapper_tys) == 1 ? C_NULL : Ref{Int64}(argidx - 1),
-                        argidx - 1,
-                        0,
-                        C_NULL,
-                    ),
-                ),
-            )
-            push!(wrapargs, MLIR.IR.argument(wrapbody, argidx))
-            argidx += 1
-            trueidx += 1
-            continue
-        end
-    end
-        
+    #     if a isa CuTracedArray
+    #         a =
+    #             Base.unsafe_pointer_to_objref(Base.reinterpret(Ptr{Cvoid}, a.ptr))::TracedRArray
+    #     end
+    #     if a isa TracedRArray || a isa TracedRNumber
+    #         push!(rarrays, a)
+    #         arg = a.mlir_data
+    #         arg = Reactant.TracedUtils.transpose_val(arg)
+    #         push!(restys, MLIR.IR.type(arg))
+    #         push!(mlir_args, arg)
+    #         push!(
+    #             aliases,
+    #             MLIR.IR.Attribute(
+    #                 MLIR.API.stablehloOutputOperandAliasGet(
+    #                     MLIR.IR.context(),
+    #                     length(wrapper_tys) == 1 ? 0 : 1,
+    #                     length(wrapper_tys) == 1 ? C_NULL : Ref{Int64}(argidx - 1),
+    #                     argidx - 1,
+    #                     0,
+    #                     C_NULL,
+    #                 ),
+    #             ),
+    #         )
+    #         push!(wrapargs, MLIR.IR.argument(wrapbody, argidx))
+    #         argidx += 1
+    #         trueidx += 1
+    #         continue
+    #     end
+    # end
+
     MLIR.IR.block!(wrapbody) do
         MLIR.Dialects.llvm.call(wrapargs, MLIR.IR.Value[]; callee=MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)), op_bundle_sizes=MLIR.IR.Attribute(Int32[]))
         MLIR.Dialects.llvm.return_(nothing)
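As an aside, the `get_field_offset` helper introduced in the diff above can be exercised on its own. The helper body is reproduced from the diff so the snippet is self-contained; the nested struct types below are hypothetical examples for illustration, not part of the PR:

```julia
# Walk a path of field names, accumulating byte offsets via fieldoffset/fieldtype
# (same logic as the helper added in the diff above).
function get_field_offset(T::Type, path)
    offset = 0
    current_type = T
    for field in path
        field_idx = findfirst(==(field), fieldnames(current_type))
        field_idx === nothing && error("Field $field not found in type $current_type")
        offset += fieldoffset(current_type, field_idx)
        current_type = fieldtype(current_type, field_idx)
    end
    return offset
end

# Hypothetical nested structs to illustrate the offset arithmetic.
struct Inner
    a::Int64
    b::Int32
end
struct Outer
    x::Int32   # offset 0
    y::Inner   # offset 8 (Inner is 8-byte aligned because of the Int64)
end

get_field_offset(Outer, (:y, :b))  # 8 (offset of y) + 8 (offset of b in Inner) = 16
```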

@mofeing (Collaborator, author) commented Jan 8, 2025

+    linear_args = Reactant.TracedType[]
+    for v in values(seen)
+        v isa Reactant.TracedType || continue
+        push!(linear_args, v)
+    end

this part is not really needed because seen is already ordered and all its values are TracedRArray or TracedRNumber right?

@jumerckx (Collaborator) commented Jan 8, 2025

+    linear_args = Reactant.TracedType[]
+    for v in values(seen)
+        v isa Reactant.TracedType || continue
+        push!(linear_args, v)
+    end

this part is not really needed because seen is already ordered and all its values are TracedRArray or TracedRNumber right?

They aren't all TracedR*; I found out during debugging that the first value was a tuple or something like that, can't remember exactly.

@jumerckx (Collaborator) commented Jan 8, 2025

@wsmoses CUDATracedSetPath, if I understood correctly what needs to happen.

diff --git a/src/Tracing.jl b/src/Tracing.jl
index e00fdcb0..09838c1a 100644
--- a/src/Tracing.jl
+++ b/src/Tracing.jl
@@ -4,6 +4,7 @@
     TracedToConcrete = 3
     ArrayToConcrete = 4
     TracedSetPath = 5
+    CUDATracedSetPath = 6
 end
 
 for T in (DataType, Module, Nothing, Symbol, AbstractChar, AbstractString, RNumber)
@@ -460,10 +461,7 @@ function make_tracer(
         end
         return prev
     end
-    if mode == TracedSetPath
-        if haskey(seen, prev)
-            return seen[prev]
-        end
+    if mode == TracedSetPath || mode == CUDATracedSetPath
         res = if toscalar
             TracedRNumber{T}((path,), nothing)
         elseif tobatch !== nothing
@@ -471,6 +469,19 @@ function make_tracer(
         else
             TracedRArray{T,N}((path,), prev.mlir_data, size(prev))
         end
+
+        # For CUDATracedSetPath, we want to set the path even if we've seen this object before
+        if mode == CUDATracedSetPath
+            if haskey(seen, prev)
+                TracedUtils.set_paths!(seen[prev], (path,))
+                return seen[prev]
+            end
+        else # Normal TracedSetPath behavior
+            if haskey(seen, prev)
+                return seen[prev]
+            end
+        end
+
         seen[prev] = res
         return res
     end
@@ -506,10 +517,7 @@ function make_tracer(
         end
         return prev
     end
-    if mode == TracedSetPath
-        if haskey(seen, prev)
-            return seen[prev]
-        end
+    if mode == TracedSetPath || mode == CUDATracedSetPath
         res = if toscalar
             TracedRNumber{T}((path,), nothing)
         elseif tobatch !== nothing
@@ -517,6 +525,19 @@ function make_tracer(
         else
             TracedRNumber{T}((path,), prev.mlir_data)
         end
+
+        # For CUDATracedSetPath, we want to set the path even if we've seen this object before
+        if mode == CUDATracedSetPath
+            if haskey(seen, prev)
+                TracedUtils.set_paths!(seen[prev], (path,))
+                return seen[prev]
+            end
+        else # Normal TracedSetPath behavior
+            if haskey(seen, prev)
+                return seen[prev]
+            end
+        end
+
         seen[prev] = res
         return res
     end
@@ -546,9 +567,21 @@ function make_tracer(
         end
         return prev
     end
-    if mode == TracedSetPath
-        haskey(seen, prev) && return seen[prev]
+    if mode == TracedSetPath || mode == CUDATracedSetPath
         res = MissingTracedValue((path,))
+
+        # For CUDATracedSetPath, we want to set the path even if we've seen this object before
+        if mode == CUDATracedSetPath
+            if haskey(seen, prev)
+                TracedUtils.set_paths!(seen[prev], (path,))
+                return seen[prev]
+            end
+        else # Normal TracedSetPath behavior
+            if haskey(seen, prev)
+                return seen[prev]
+            end
+        end
+
         seen[res] = res
         return res
     end

@wsmoses (Member) commented Jan 8, 2025

@wsmoses CUDATracedSetPath, if I understood correctly what needs to happen.


Not quite; basically we want it to be equivalent to TracedTrack, but we don't exit early if we've seen the object before in the dict, so perhaps NoStopTracedTrack would be a good name.
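A minimal sketch of the two traversal behaviors being discussed (names are illustrative, not Reactant's actual API): in stop-on-seen mode, an object reached twice contributes only its first path; in the proposed no-stop mode, every path to it is recorded.

```julia
# stop_on_seen=true  ≈ early-exit behavior (TracedSetPath-like)
# stop_on_seen=false ≈ proposed NoStopTracedTrack: keep recording paths
function trace!(seen::IdDict, paths::Vector{Any}, obj, path; stop_on_seen::Bool)
    if haskey(seen, obj)
        stop_on_seen && return seen[obj]   # early exit: the extra path is dropped
        push!(paths, path)                 # no-stop: still record this path
        return seen[obj]
    end
    seen[obj] = obj
    push!(paths, path)
    return obj
end

a = [1, 2, 3]
args = (a, a)                              # the same array aliased twice

stop_paths = Any[]; seen = IdDict()
foreach(i -> trace!(seen, stop_paths, args[i], (i,); stop_on_seen=true), 1:2)

all_paths = Any[]; seen = IdDict()
foreach(i -> trace!(seen, all_paths, args[i], (i,); stop_on_seen=false), 1:2)

(length(stop_paths), length(all_paths))    # (1, 2)
```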

@jumerckx (Collaborator) commented Jan 8, 2025

Not quite, basically we want it to be equivalent to TracedTrack

Ah right, won't be for me tonight anymore.

@wsmoses (Member) commented Jan 8, 2025

Yeah for sure get some sleep

@@ -281,6 +281,7 @@ function compile(job)
# TODO: on 1.9, this actually creates a context. cache those.
entry = GPUCompiler.JuliaContext() do ctx
mod, meta = GPUCompiler.compile(
# :llvm, job; optimize=false, cleanup=false, validate=false, libraries=true
:llvm, job; optimize=false, cleanup=false, validate=false, libraries=false
[JuliaFormatter] reported by reviewdog 🐶

Suggested change
:llvm, job; optimize=false, cleanup=false, validate=false, libraries=false
:llvm,
job;
optimize=false,
cleanup=false,
validate=false,
libraries=false,

Comment on lines 85 to 86
end

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
end

Comment on lines +98 to +99
tuplef(a) = @cuda threads=1 tuplef!((a,))
tuplef2(a) = @cuda threads=1 tuplef2!((5, a))
[JuliaFormatter] reported by reviewdog 🐶

Suggested change
tuplef(a) = @cuda threads=1 tuplef!((a,))
tuplef2(a) = @cuda threads=1 tuplef2!((5, a))
tuplef(a) = @cuda threads = 1 tuplef!((a,))
tuplef2(a) = @cuda threads = 1 tuplef2!((5, a))

Comment on lines +102 to +108
@testset "Structured Kernel Arguments" begin
A = ConcreteRArray(fill(1))
if CUDA.functional()
@jit tuplef(A)
@test all(Array(A) .≈ 3)
else
@code_hlo optimize = :before_kernel tuplef(A)
[JuliaFormatter] reported by reviewdog 🐶

Suggested change
@testset "Structured Kernel Arguments" begin
A = ConcreteRArray(fill(1))
if CUDA.functional()
@jit tuplef(A)
@test all(Array(A) .≈ 3)
else
@code_hlo optimize = :before_kernel tuplef(A)
@testset "Structured Kernel Arguments" begin
A = ConcreteRArray(fill(1))
if CUDA.functional()
@jit tuplef(A)
@test all(Array(A) .≈ 3)
else
@code_hlo optimize = :before_kernel tuplef(A)
end
A = ConcreteRArray(fill(1))
if CUDA.functional()
@jit tuplef2(A)
@test all(Array(A) .≈ 5)
else
@code_hlo optimize = :before_kernel tuplef2(A)
end

Comment on lines +110 to +119

A = ConcreteRArray(fill(1))
if CUDA.functional()
@jit tuplef2(A)
@test all(Array(A) .≈ 5)
else
@code_hlo optimize = :before_kernel tuplef2(A)
end

end
[JuliaFormatter] reported by reviewdog 🐶

Suggested change
A = ConcreteRArray(fill(1))
if CUDA.functional()
@jit tuplef2(A)
@test all(Array(A) .≈ 5)
else
@code_hlo optimize = :before_kernel tuplef2(A)
end
end

[JuliaFormatter] reported by reviewdog 🐶

sz = sizeof(x)
ref = Ref(x)
GC.@preserve ref begin
ptr = Base.reinterpret(Ptr{UInt8}, Base.unsafe_convert(Ptr{Cvoid}, ref))
vec = Vector{UInt8}(undef, sz)
for i in 1:sz
@inbounds vec[i] = Base.unsafe_load(ptr, i)
end
vec
end
end
function Reactant.make_tracer(seen, @nospecialize(prev::CuTracedArray), @nospecialize(path), mode; kwargs...)
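The byte-serialization pattern shown in the review context above, made self-contained for illustration (the `to_bytes` signature is an assumption, since the scrape cut off the surrounding definition; it assumes an isbits value):

```julia
# Copy the raw bytes of an isbits value into a Vector{UInt8}, pinning the Ref
# with GC.@preserve while reading through the raw pointer.
function to_bytes(x)
    sz = sizeof(x)
    ref = Ref(x)
    GC.@preserve ref begin
        ptr = Base.reinterpret(Ptr{UInt8}, Base.unsafe_convert(Ptr{Cvoid}, ref))
        vec = Vector{UInt8}(undef, sz)
        for i in 1:sz
            @inbounds vec[i] = Base.unsafe_load(ptr, i)
        end
        vec
    end
end

v = to_bytes(Int32(1))                # 4 raw bytes
reinterpret(Int32, v)[1] == Int32(1)  # round-trips regardless of endianness
```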

[JuliaFormatter] reported by reviewdog 🐶

@testset "Constant Op Kernel" begin
oA = collect(1:1:64)
A = Reactant.to_rarray(oA)
if CUDA.functional()
@jit smul!(A)
@test all(Array(A) .≈ oA .* 15)
else
@code_hlo optimize = :before_kernel smul!(A)

findfirst(==(field), fieldnames(current_type))
end
if field_idx === nothing
error("Field $field not found in type $current_type, fieldnames=$(fieldnames(current_type)) T=$T path=$path")
[JuliaFormatter] reported by reviewdog 🐶

Suggested change
error("Field $field not found in type $current_type, fieldnames=$(fieldnames(current_type)) T=$T path=$path")
error(
"Field $field not found in type $current_type, fieldnames=$(fieldnames(current_type)) T=$T path=$path",
)

@@ -384,20 +417,19 @@

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

continue
end
# Per below we assume we can inline all other types directly in
push!(wrapper_tys, cullvm_ty)
end

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

@@ -426,20 +458,60 @@
gpu_function_type = MLIR.IR.Type(Reactant.TracedUtils.get_attribute_by_name(gpufunc, "function_type"))

trueidx = 1
for (i, a) in Tuple{Int, Any}[(0, func.f), enumerate(args)...]
allocs = Union{Tuple{MLIR.IR.Value, MLIR.IR.Type}, Nothing}[]
[JuliaFormatter] reported by reviewdog 🐶

Suggested change
allocs = Union{Tuple{MLIR.IR.Value, MLIR.IR.Type}, Nothing}[]
allocs = Union{Tuple{MLIR.IR.Value,MLIR.IR.Type},Nothing}[]


# TODO check for only integer and explicitly non cutraced types
MLIR.IR.block!(wrapbody) do
argty = MLIR.IR.Type(MLIR.API.mlirLLVMFunctionTypeGetInput(gpu_function_type, trueidx-1))
[JuliaFormatter] reported by reviewdog 🐶

Suggested change
argty = MLIR.IR.Type(MLIR.API.mlirLLVMFunctionTypeGetInput(gpu_function_type, trueidx-1))
argty = MLIR.IR.Type(
MLIR.API.mlirLLVMFunctionTypeGetInput(gpu_function_type, trueidx - 1)
)

@@ -453,30 +525,20 @@
),
),
)
push!(wrapargs, MLIR.IR.argument(wrapbody, argidx))

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

cdata = MLIR.IR.result(MLIR.Dialects.llvm.mlir_constant(; res=array_ty, value=MLIR.IR.DenseElementsAttribute(to_bytes(a))), 1)
MLIR.Dialects.llvm.store(cdata, alloc)
argres = MLIR.IR.result(MLIR.Dialects.llvm.load(alloc; res=argty), 1)
push!(wrapargs, argres)
end
end

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

alloc, argty = arg
argres = MLIR.IR.result(MLIR.Dialects.llvm.load(alloc; res=argty), 1)
push!(wrapargs, argres)
end
MLIR.Dialects.llvm.call(wrapargs, MLIR.IR.Value[]; callee=MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)), op_bundle_sizes=MLIR.IR.Attribute(Int32[]))
[JuliaFormatter] reported by reviewdog 🐶

Suggested change
MLIR.Dialects.llvm.call(wrapargs, MLIR.IR.Value[]; callee=MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)), op_bundle_sizes=MLIR.IR.Attribute(Int32[]))
MLIR.Dialects.llvm.call(
wrapargs,
MLIR.IR.Value[];
callee=MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)),
op_bundle_sizes=MLIR.IR.Attribute(Int32[]),
)

@@ -500,8 +562,14 @@
fn=MLIR.IR.FlatSymbolRefAttribute(sym_name),
output_operand_aliases=MLIR.IR.Attribute(output_operand_aliases)
[JuliaFormatter] reported by reviewdog 🐶

Suggested change
output_operand_aliases=MLIR.IR.Attribute(output_operand_aliases)
output_operand_aliases=MLIR.IR.Attribute(output_operand_aliases),

continue
end
arg.mlir_data = Reactant.TracedUtils.transpose_val(MLIR.IR.result(call, argidx))
argidx+=1
[JuliaFormatter] reported by reviewdog 🐶

Suggested change
argidx+=1
argidx += 1

[JuliaFormatter] reported by reviewdog 🐶

MLIR.IR.attr!(gpufunc, "CConv", MLIR.IR.Attribute(MLIR.API.mlirLLVMCConvAttrGet(ctx, MLIR.API.MlirLLVMCConvC)))
gpu_function_type = MLIR.IR.Type(Reactant.TracedUtils.get_attribute_by_name(gpufunc, "function_type"))

@wsmoses (Member) commented Jan 9, 2025

get_field_offset(typeof(args[p[2]-1]), p[3:end])

No, the ordered dict can contain anything (e.g. it's needed to stop infinite recursion for, say, a cyclic linked list).
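A minimal sketch of why the `seen` dict must admit arbitrary (non-traced) objects: without marking a node before recursing, a cyclic structure would never terminate (the `Node` type and `collect_values!` are illustrative, not Reactant code).

```julia
# A cyclic linked list; the IdDict guard is what makes traversal terminate.
mutable struct Node
    value::Int
    next::Union{Node,Nothing}
end

function collect_values!(seen::IdDict, out::Vector{Int}, n::Union{Node,Nothing})
    n === nothing && return out
    haskey(seen, n) && return out   # cycle guard: the Node itself lives in `seen`
    seen[n] = n
    push!(out, n.value)
    return collect_values!(seen, out, n.next)
end

n1 = Node(1, nothing)
n2 = Node(2, n1)
n1.next = n2                         # close the cycle: n1 -> n2 -> n1 -> ...

collect_values!(IdDict(), Int[], n1) # terminates thanks to the guard
```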
