-
Notifications
You must be signed in to change notification settings - Fork 9
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
linearize kernel args #497
base: main
Are you sure you want to change the base?
Conversation
Potentially useful: while messing with the kernel code, changing the second loop to:
produces the following MLIR IR:
#tbaa_root = #llvm.tbaa_root<id = "custom_tbaa">
#tbaa_type_desc = #llvm.tbaa_type_desc<id = "custom_tbaa_addrspace(1)", members = {<#tbaa_root, 0>}>
#tbaa_tag = #llvm.tbaa_tag<base_type = #tbaa_type_desc, access_type = #tbaa_type_desc, offset = 0>
"builtin.module"() ({
"llvm.func"() <{CConv = #llvm.cconv<ccc>, function_type = !llvm.func<void (array<1 x array<1 x ptr<1>>>)>, linkage = #llvm.linkage<external>, sym_name = "_Z2f_5TupleI13CuTracedArrayI5Int64Ll0ELl1E2__EE", sym_visibility = "private", unnamed_addr = 1 : i64, visibility_ = 0 : i64}> ({
^bb0(%arg3: !llvm.array<1 x array<1 x ptr<1>>>):
%8 = "llvm.mlir.constant"() <{value = 2 : i64}> : () -> i64
%9 = "llvm.extractvalue"(%arg3) <{position = array<i64: 0>}> : (!llvm.array<1 x array<1 x ptr<1>>>) -> !llvm.array<1 x ptr<1>>
%10 = "llvm.extractvalue"(%9) <{position = array<i64: 0>}> : (!llvm.array<1 x ptr<1>>) -> !llvm.ptr<1>
%11 = "llvm.bitcast"(%10) : (!llvm.ptr<1>) -> !llvm.ptr<1>
"llvm.store"(%8, %11) <{alignment = 1 : i64, ordering = 0 : i64, tbaa = [#tbaa_tag]}> : (i64, !llvm.ptr<1>) -> ()
"llvm.return"() : () -> ()
}) : () -> ()
"llvm.func"() <{CConv = #llvm.cconv<ptx_kernelcc>, function_type = !llvm.func<void (array<1 x ptr<1>>, array<1 x ptr<1>>)>, linkage = #llvm.linkage<external>, sym_name = "##call__Z2f_5TupleI13CuTracedArrayI5Int64Ll0ELl1E2__EE#236", sym_visibility = "private", visibility_ = 0 : i64}> ({
^bb0(%arg1: !llvm.array<1 x ptr<1>>, %arg2: !llvm.array<1 x ptr<1>>):
"llvm.call"(%arg1) <{CConv = #llvm.cconv<ccc>, TailCallKind = #llvm.tailcallkind<none>, callee = @_Z2f_5TupleI13CuTracedArrayI5Int64Ll0ELl1E2__EE, fastmathFlags = #llvm.fastmath<none>, op_bundle_sizes = array<i32>, operandSegmentSizes = array<i32: 1, 0>}> : (!llvm.array<1 x ptr<1>>) -> ()
"llvm.return"() : () -> ()
}) : () -> ()
"func.func"() <{function_type = (tensor<i64>) -> tensor<i64>, sym_name = "main"}> ({
^bb0(%arg0: tensor<i64>):
%0 = "stablehlo.constant"() <{value = dense<1> : tensor<i64>}> : () -> tensor<i64>
%1 = "stablehlo.constant"() <{value = dense<1> : tensor<i64>}> : () -> tensor<i64>
%2 = "stablehlo.constant"() <{value = dense<1> : tensor<i64>}> : () -> tensor<i64>
%3 = "stablehlo.constant"() <{value = dense<1> : tensor<i64>}> : () -> tensor<i64>
%4 = "stablehlo.constant"() <{value = dense<1> : tensor<i64>}> : () -> tensor<i64>
%5 = "stablehlo.constant"() <{value = dense<1> : tensor<i64>}> : () -> tensor<i64>
%6 = "stablehlo.constant"() <{value = dense<0> : tensor<i64>}> : () -> tensor<i64>
%7 = "enzymexla.kernel_call"(%0, %1, %2, %3, %4, %5, %6, %arg0) <{backend_config = "", fn = @"##call__Z2f_5TupleI13CuTracedArrayI5Int64Ll0ELl1E2__EE#236", output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 0, operand_tuple_indices = []>]}> : (tensor<i64>, tensor<i64>, tensor<i64>, tensor<i64>, tensor<i64>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<i64>
"func.return"(%arg0) : (tensor<i64>) -> ()
}) : () -> ()
}) : () -> ()

which errors with:
😅 |
Co-authored-by: jumerckx <31353884+jumerckx@users.noreply.github.com>
eeba06e
to
839050b
Compare
Perhaps a step in the right direction for offset computation 🤞.

diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl
index 0801f813..7f8fa858 100644
--- a/ext/ReactantCUDAExt.jl
+++ b/ext/ReactantCUDAExt.jl
@@ -368,6 +368,27 @@ function Reactant.make_tracer(seen, @nospecialize(prev::CuTracedArray), @nospeci
return Reactant.make_tracer(seen, x, path, mode; kwargs...)
end
+function get_field_offset(T::Type, path)
+ offset = 0
+ current_type = T
+
+ for field in path
+ # Get the field index
+ field_idx = findfirst(==(field), fieldnames(current_type))
+ if field_idx === nothing
+ error("Field $field not found in type $current_type")
+ end
+
+ # Add the offset of this field
+ offset += fieldoffset(current_type, field_idx)
+
+ # Update current_type to the field's type for next iteration
+ current_type = fieldtype(current_type, field_idx)
+ end
+
+ return offset
+end
+
Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
args...;
convert=Val(false),
@@ -386,7 +407,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
rarrays = TracedRArray[]
fname = func.entry
-
+
wrapper_tys = MLIR.IR.Type[]
ctx = MLIR.IR.context()
cullvm_ty = MLIR.IR.Type(MLIR.API.mlirLLVMArrayTypeGet(MLIR.API.mlirLLVMPointerTypeGet(ctx, 1), 1))
@@ -395,9 +416,14 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
seen = Reactant.OrderedIdDict()
prev = Any[func.f, args...]
kernelargsym = gensym("kernelarg")
- make_tracer(seen, prev, (kernelargsym,), Reactant.TracedSetPath)
- wrapper_tys = fill(cullvm_ty, length(seen))
-
+ Reactant.make_tracer(seen, prev, (kernelargsym,), Reactant.TracedSetPath)
+ linear_args = Reactant.TracedType[]
+ for v in values(seen)
+ v isa Reactant.TracedType || continue
+ push!(linear_args, v)
+ end
+ wrapper_tys = fill(cullvm_ty, length(linear_args))
+
sym_name = String(gensym("call_$fname"))
mod = MLIR.IR.mmodule()
CConv=MLIR.IR.Attribute(MLIR.API.mlirLLVMCConvAttrGet(ctx, MLIR.API.MlirLLVMCConvPTX_Kernel))
@@ -437,7 +463,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
c1 = MLIR.IR.result(MLIR.Dialects.llvm.mlir_constant(; res=MLIR.IR.Type(Int64), value=MLIR.IR.Attribute(1)), 1)
alloc = MLIR.IR.result(MLIR.Dialects.llvm.alloca(c1; elem_type=MLIR.IR.Attribute(argty), res=MLIR.IR.Type(MLIR.API.mlirLLVMPointerTypeGet(ctx, 0))), 1)
push!(allocs, alloc)
-
+
sz = sizeof(a)
array_ty = MLIR.IR.Type(MLIR.API.mlirLLVMArrayTypeGet(MLIR.IR.Type(Int8), sz))
cdata = MLIR.IR.result(MLIR.Dialects.llvm.mlir_constant(; res=array_ty, value=MLIR.IR.DenseElementsAttribute(to_bytes(a))), 1)
@@ -449,54 +475,55 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
end
argidx = 1
- for arg in values(seen)
+ for arg in values(linear_args)
for p in Reactant.TracedUtils.get_paths(arg)
if p[1] !== kernelargsym
continue
end
# Get the allocation corresponding to which arg we're doing
- alloc = allocs[p[2]]
+ alloc = allocs[p[2]-1] # some off-by-one confusion?
# we need to now compute the offset in bytes of the path
- offset = 0
- ptr = MLIR.gep!(alloc
+ offset = get_field_offset(typeof(args[p[2]-1]), p[3:end])
+ @warn offset
+ # ptr = MLIR.gep!(alloc
- store ptr = arg of wrapped index
+ # store ptr = arg of wrapped index
- argidx += 1
+ # argidx += 1
end
end
- if a isa CuTracedArray
- a =
- Base.unsafe_pointer_to_objref(Base.reinterpret(Ptr{Cvoid}, a.ptr))::TracedRArray
- end
- if a isa TracedRArray || a isa TracedRNumber
- push!(rarrays, a)
- arg = a.mlir_data
- arg = Reactant.TracedUtils.transpose_val(arg)
- push!(restys, MLIR.IR.type(arg))
- push!(mlir_args, arg)
- push!(
- aliases,
- MLIR.IR.Attribute(
- MLIR.API.stablehloOutputOperandAliasGet(
- MLIR.IR.context(),
- length(wrapper_tys) == 1 ? 0 : 1,
- length(wrapper_tys) == 1 ? C_NULL : Ref{Int64}(argidx - 1),
- argidx - 1,
- 0,
- C_NULL,
- ),
- ),
- )
- push!(wrapargs, MLIR.IR.argument(wrapbody, argidx))
- argidx += 1
- trueidx += 1
- continue
- end
- end
-
+ # if a isa CuTracedArray
+ # a =
+ # Base.unsafe_pointer_to_objref(Base.reinterpret(Ptr{Cvoid}, a.ptr))::TracedRArray
+ # end
+ # if a isa TracedRArray || a isa TracedRNumber
+ # push!(rarrays, a)
+ # arg = a.mlir_data
+ # arg = Reactant.TracedUtils.transpose_val(arg)
+ # push!(restys, MLIR.IR.type(arg))
+ # push!(mlir_args, arg)
+ # push!(
+ # aliases,
+ # MLIR.IR.Attribute(
+ # MLIR.API.stablehloOutputOperandAliasGet(
+ # MLIR.IR.context(),
+ # length(wrapper_tys) == 1 ? 0 : 1,
+ # length(wrapper_tys) == 1 ? C_NULL : Ref{Int64}(argidx - 1),
+ # argidx - 1,
+ # 0,
+ # C_NULL,
+ # ),
+ # ),
+ # )
+ # push!(wrapargs, MLIR.IR.argument(wrapbody, argidx))
+ # argidx += 1
+ # trueidx += 1
+ # continue
+ # end
+ # end
+
MLIR.IR.block!(wrapbody) do
MLIR.Dialects.llvm.call(wrapargs, MLIR.IR.Value[]; callee=MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)), op_bundle_sizes=MLIR.IR.Attribute(Int32[]))
MLIR.Dialects.llvm.return_(nothing)
|
this part is not really needed because |
They aren't all TracedR*; I found out during debugging that the first value was a tuple or something like that — I can't remember exactly. |
@wsmoses

diff --git a/src/Tracing.jl b/src/Tracing.jl
index e00fdcb0..09838c1a 100644
--- a/src/Tracing.jl
+++ b/src/Tracing.jl
@@ -4,6 +4,7 @@
TracedToConcrete = 3
ArrayToConcrete = 4
TracedSetPath = 5
+ CUDATracedSetPath = 6
end
for T in (DataType, Module, Nothing, Symbol, AbstractChar, AbstractString, RNumber)
@@ -460,10 +461,7 @@ function make_tracer(
end
return prev
end
- if mode == TracedSetPath
- if haskey(seen, prev)
- return seen[prev]
- end
+ if mode == TracedSetPath || mode == CUDATracedSetPath
res = if toscalar
TracedRNumber{T}((path,), nothing)
elseif tobatch !== nothing
@@ -471,6 +469,19 @@ function make_tracer(
else
TracedRArray{T,N}((path,), prev.mlir_data, size(prev))
end
+
+ # For CUDATracedSetPath, we want to set the path even if we've seen this object before
+ if mode == CUDATracedSetPath
+ if haskey(seen, prev)
+ TracedUtils.set_paths!(seen[prev], (path,))
+ return seen[prev]
+ end
+ else # Normal TracedSetPath behavior
+ if haskey(seen, prev)
+ return seen[prev]
+ end
+ end
+
seen[prev] = res
return res
end
@@ -506,10 +517,7 @@ function make_tracer(
end
return prev
end
- if mode == TracedSetPath
- if haskey(seen, prev)
- return seen[prev]
- end
+ if mode == TracedSetPath || mode == CUDATracedSetPath
res = if toscalar
TracedRNumber{T}((path,), nothing)
elseif tobatch !== nothing
@@ -517,6 +525,19 @@ function make_tracer(
else
TracedRNumber{T}((path,), prev.mlir_data)
end
+
+ # For CUDATracedSetPath, we want to set the path even if we've seen this object before
+ if mode == CUDATracedSetPath
+ if haskey(seen, prev)
+ TracedUtils.set_paths!(seen[prev], (path,))
+ return seen[prev]
+ end
+ else # Normal TracedSetPath behavior
+ if haskey(seen, prev)
+ return seen[prev]
+ end
+ end
+
seen[prev] = res
return res
end
@@ -546,9 +567,21 @@ function make_tracer(
end
return prev
end
- if mode == TracedSetPath
- haskey(seen, prev) && return seen[prev]
+ if mode == TracedSetPath || mode == CUDATracedSetPath
res = MissingTracedValue((path,))
+
+ # For CUDATracedSetPath, we want to set the path even if we've seen this object before
+ if mode == CUDATracedSetPath
+ if haskey(seen, prev)
+ TracedUtils.set_paths!(seen[prev], (path,))
+ return seen[prev]
+ end
+ else # Normal TracedSetPath behavior
+ if haskey(seen, prev)
+ return seen[prev]
+ end
+ end
+
seen[res] = res
return res
end |
Not quite — basically we want it to be equivalent to TracedTrack, but we don't exit early if we've seen the object before in the dict; so perhaps NoStopTracedTrack would be a good name. |
Ah right, won't be for me tonight anymore. |
Yeah for sure get some sleep |
@@ -281,6 +281,7 @@ function compile(job) | |||
# TODO: on 1.9, this actually creates a context. cache those. | |||
entry = GPUCompiler.JuliaContext() do ctx | |||
mod, meta = GPUCompiler.compile( | |||
# :llvm, job; optimize=false, cleanup=false, validate=false, libraries=true | |||
:llvm, job; optimize=false, cleanup=false, validate=false, libraries=false |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
:llvm, job; optimize=false, cleanup=false, validate=false, libraries=false | |
:llvm, | |
job; | |
optimize=false, | |
cleanup=false, | |
validate=false, | |
libraries=false, |
end | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
end |
tuplef(a) = @cuda threads=1 tuplef!((a,)) | ||
tuplef2(a) = @cuda threads=1 tuplef2!((5, a)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
tuplef(a) = @cuda threads=1 tuplef!((a,)) | |
tuplef2(a) = @cuda threads=1 tuplef2!((5, a)) | |
tuplef(a) = @cuda threads = 1 tuplef!((a,)) | |
tuplef2(a) = @cuda threads = 1 tuplef2!((5, a)) |
@testset "Structured Kernel Arguments" begin | ||
A = ConcreteRArray(fill(1)) | ||
if CUDA.functional() | ||
@jit tuplef(A) | ||
@test all(Array(A) .≈ 3) | ||
else | ||
@code_hlo optimize = :before_kernel tuplef(A) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
@testset "Structured Kernel Arguments" begin | |
A = ConcreteRArray(fill(1)) | |
if CUDA.functional() | |
@jit tuplef(A) | |
@test all(Array(A) .≈ 3) | |
else | |
@code_hlo optimize = :before_kernel tuplef(A) | |
@testset "Structured Kernel Arguments" begin | |
A = ConcreteRArray(fill(1)) | |
if CUDA.functional() | |
@jit tuplef(A) | |
@test all(Array(A) .≈ 3) | |
else | |
@code_hlo optimize = :before_kernel tuplef(A) | |
end | |
A = ConcreteRArray(fill(1)) | |
if CUDA.functional() | |
@jit tuplef2(A) | |
@test all(Array(A) .≈ 5) | |
else | |
@code_hlo optimize = :before_kernel tuplef2(A) | |
end |
|
||
A = ConcreteRArray(fill(1)) | ||
if CUDA.functional() | ||
@jit tuplef2(A) | ||
@test all(Array(A) .≈ 5) | ||
else | ||
@code_hlo optimize = :before_kernel tuplef2(A) | ||
end | ||
|
||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
A = ConcreteRArray(fill(1)) | |
if CUDA.functional() | |
@jit tuplef2(A) | |
@test all(Array(A) .≈ 5) | |
else | |
@code_hlo optimize = :before_kernel tuplef2(A) | |
end | |
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
Reactant.jl/ext/ReactantCUDAExt.jl
Lines 355 to 367 in 832a20c
sz = sizeof(x) | |
ref = Ref(x) | |
GC.@preserve ref begin | |
ptr = Base.reinterpret(Ptr{UInt8}, Base.unsafe_convert(Ptr{Cvoid}, ref)) | |
vec = Vector{UInt8}(undef, sz) | |
for i in 1:sz | |
@inbounds vec[i] = Base.unsafe_load(ptr, i) | |
end | |
vec | |
end | |
end | |
function Reactant.make_tracer(seen, @nospecialize(prev::CuTracedArray), @nospecialize(path), mode; kwargs...) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
Reactant.jl/test/integration/cuda.jl
Lines 75 to 82 in 832a20c
@testset "Constant Op Kernel" begin | |
oA = collect(1:1:64) | |
A = Reactant.to_rarray(oA) | |
if CUDA.functional() | |
@jit smul!(A) | |
@test all(Array(A) .≈ oA .* 15) | |
else | |
@code_hlo optimize = :before_kernel smul!(A) |
findfirst(==(field), fieldnames(current_type)) | ||
end | ||
if field_idx === nothing | ||
error("Field $field not found in type $current_type, fieldnames=$(fieldnames(current_type)) T=$T path=$path") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
error("Field $field not found in type $current_type, fieldnames=$(fieldnames(current_type)) T=$T path=$path") | |
error( | |
"Field $field not found in type $current_type, fieldnames=$(fieldnames(current_type)) T=$T path=$path", | |
) |
@@ -384,20 +417,19 @@ | |||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
continue | ||
end | ||
# Per below we assume we can inline all other types directly in | ||
push!(wrapper_tys, cullvm_ty) | ||
end | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
@@ -426,20 +458,60 @@ | |||
gpu_function_type = MLIR.IR.Type(Reactant.TracedUtils.get_attribute_by_name(gpufunc, "function_type")) | |||
|
|||
trueidx = 1 | |||
for (i, a) in Tuple{Int, Any}[(0, func.f), enumerate(args)...] | |||
allocs = Union{Tuple{MLIR.IR.Value, MLIR.IR.Type}, Nothing}[] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
allocs = Union{Tuple{MLIR.IR.Value, MLIR.IR.Type}, Nothing}[] | |
allocs = Union{Tuple{MLIR.IR.Value,MLIR.IR.Type},Nothing}[] |
|
||
# TODO check for only integer and explicitly non cutraced types | ||
MLIR.IR.block!(wrapbody) do | ||
argty = MLIR.IR.Type(MLIR.API.mlirLLVMFunctionTypeGetInput(gpu_function_type, trueidx-1)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
argty = MLIR.IR.Type(MLIR.API.mlirLLVMFunctionTypeGetInput(gpu_function_type, trueidx-1)) | |
argty = MLIR.IR.Type( | |
MLIR.API.mlirLLVMFunctionTypeGetInput(gpu_function_type, trueidx - 1) | |
) |
@@ -453,30 +525,20 @@ | |||
), | |||
), | |||
) | |||
push!(wrapargs, MLIR.IR.argument(wrapbody, argidx)) | |||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
cdata = MLIR.IR.result(MLIR.Dialects.llvm.mlir_constant(; res=array_ty, value=MLIR.IR.DenseElementsAttribute(to_bytes(a))), 1) | ||
MLIR.Dialects.llvm.store(cdata, alloc) | ||
argres = MLIR.IR.result(MLIR.Dialects.llvm.load(alloc; res=argty), 1) | ||
push!(wrapargs, argres) | ||
end | ||
end | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
alloc, argty = arg | ||
argres = MLIR.IR.result(MLIR.Dialects.llvm.load(alloc; res=argty), 1) | ||
push!(wrapargs, argres) | ||
end | ||
MLIR.Dialects.llvm.call(wrapargs, MLIR.IR.Value[]; callee=MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)), op_bundle_sizes=MLIR.IR.Attribute(Int32[])) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
MLIR.Dialects.llvm.call(wrapargs, MLIR.IR.Value[]; callee=MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)), op_bundle_sizes=MLIR.IR.Attribute(Int32[])) | |
MLIR.Dialects.llvm.call( | |
wrapargs, | |
MLIR.IR.Value[]; | |
callee=MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)), | |
op_bundle_sizes=MLIR.IR.Attribute(Int32[]), | |
) |
@@ -500,8 +562,14 @@ | |||
fn=MLIR.IR.FlatSymbolRefAttribute(sym_name), | |||
output_operand_aliases=MLIR.IR.Attribute(output_operand_aliases) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
output_operand_aliases=MLIR.IR.Attribute(output_operand_aliases) | |
output_operand_aliases=MLIR.IR.Attribute(output_operand_aliases), |
continue | ||
end | ||
arg.mlir_data = Reactant.TracedUtils.transpose_val(MLIR.IR.result(call, argidx)) | ||
argidx+=1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
argidx+=1 | |
argidx += 1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
Reactant.jl/ext/ReactantCUDAExt.jl
Lines 457 to 458 in 8127ed6
MLIR.IR.attr!(gpufunc, "CConv", MLIR.IR.Attribute(MLIR.API.mlirLLVMCConvAttrGet(ctx, MLIR.API.MlirLLVMCConvC))) | |
gpu_function_type = MLIR.IR.Type(Reactant.TracedUtils.get_attribute_by_name(gpufunc, "function_type")) |
No, the ordered dict can contain anything (e.g., it is needed to stop infinite recursion for, say, a cyclic linked list). |
CC @wsmoses