Error with vcat #139

Closed
gdalle opened this issue Oct 1, 2024 · 2 comments · Fixed by #163
Labels
good first issue Good for newcomers

Comments

gdalle commented Oct 1, 2024

The root issue seems to be the conversion of a 0-dimensional traced array (Reactant.TracedRArray{Float64, 0}) to a Float64. Is that supposed to be possible?

julia> using Reactant

julia> f(x) = vcat(x, x)
f (generic function with 1 method)

julia> xr = ConcreteRArray([1.0])
1-element ConcreteRArray{Float64, 1}:
 1.0

julia> f(xr)
2-element ConcreteRArray{Float64, 1}:
 1.0
 1.0

julia> @compile f(xr)
┌ Warning: Performing scalar indexing on task Task (runnable) @0x00007fe657a72720.
│ Invocation resulted in scalar indexing of a TracedRArray.
│ This is typically caused by calling an iterating implementation of a method.
│ Such implementations *do not* execute on device, but very slowly on the CPU,
│ and require expensive copies and synchronization each time and therefore should be avoided.
└ @ Reactant ~/.julia/packages/Reactant/MAkBF/src/TracedRArray.jl:53
ERROR: MethodError: Cannot `convert` an object of type Reactant.TracedRArray{Float64, 0} to an object of type Float64

Closest candidates are:
  convert(::Type{T}, ::T) where T<:Number
   @ Base number.jl:6
  convert(::Type{T}, ::LLVM.ConstantFP) where T<:AbstractFloat
   @ LLVM ~/.julia/packages/LLVM/joxPv/src/core/value/constant.jl:208
  convert(::Type{T}, ::ConcreteRArray{T, 0}) where T
   @ Reactant ~/.julia/packages/Reactant/MAkBF/src/ConcreteRArray.jl:73
  ...

Stacktrace:
  [1] setindex!(A::Vector{Float64}, x::Reactant.TracedRArray{Float64, 0}, i1::Int64)
    @ Base ./array.jl:1021
  [2] setindex!
    @ ./array.jl:1041 [inlined]
  [3] _typed_vcat!(a::Vector{Float64}, V::Tuple{Reactant.TracedRArray{Float64, 1}, Reactant.TracedRArray{Float64, 1}})
    @ Base ./abstractarray.jl:1640
  [4] _typed_vcat
    @ ./abstractarray.jl:1632 [inlined]
  [5] typed_vcat
    @ ./abstractarray.jl:1716 [inlined]
  [6] vcat
    @ ~/.julia/juliaup/julia-1.10.5+0.x64.linux.gnu/share/julia/stdlib/v1.10/SparseArrays/src/sparsevector.jl:1239 [inlined]
  [7] f
    @ ~/Work/GitHub/Julia/DifferentiationInterface.jl/DifferentiationInterface/test/Down/Reactant/test.jl:34 [inlined]
  [8] (::Tuple{})(none::Reactant.TracedRArray{Float64, 1})
    @ Base.Experimental ./<missing>:0
  [9] (::Reactant.var"#26#35"{typeof(f), Reactant.MLIR.IR.Block, Vector{}, Tuple{}})()
    @ Reactant ~/.julia/packages/Reactant/MAkBF/src/utils.jl:100
 [10] block!(f::Reactant.var"#26#35"{typeof(f), Reactant.MLIR.IR.Block, Vector{}, Tuple{}}, blk::Reactant.MLIR.IR.Block)
    @ Reactant.MLIR.IR ~/.julia/packages/Reactant/MAkBF/src/mlir/IR/Block.jl:201
 [11] make_mlir_fn(f::Function, args::Tuple{…}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool)
    @ Reactant ~/.julia/packages/Reactant/MAkBF/src/utils.jl:74
 [12] make_mlir_fn
    @ ~/.julia/packages/Reactant/MAkBF/src/utils.jl:24 [inlined]
 [13] #6
    @ ~/.julia/packages/Reactant/MAkBF/src/Compiler.jl:253 [inlined]
 [14] block!(f::Reactant.Compiler.var"#6#11"{typeof(f), Tuple{ConcreteRArray{Float64, 1}}}, blk::Reactant.MLIR.IR.Block)
    @ Reactant.MLIR.IR ~/.julia/packages/Reactant/MAkBF/src/mlir/IR/Block.jl:201
 [15] #5
    @ ~/.julia/packages/Reactant/MAkBF/src/Compiler.jl:252 [inlined]
 [16] mmodule!(f::Reactant.Compiler.var"#5#10"{Reactant.MLIR.IR.Module, typeof(f), Tuple{}}, blk::Reactant.MLIR.IR.Module)
    @ Reactant.MLIR.IR ~/.julia/packages/Reactant/MAkBF/src/mlir/IR/Module.jl:93
 [17] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{ConcreteRArray{Float64, 1}}; optimize::Bool)
    @ Reactant.Compiler ~/.julia/packages/Reactant/MAkBF/src/Compiler.jl:249
 [18] compile_mlir!
    @ ~/.julia/packages/Reactant/MAkBF/src/Compiler.jl:248 [inlined]
 [19] (::Reactant.Compiler.var"#30#32"{typeof(f), Tuple{ConcreteRArray{Float64, 1}}})()
    @ Reactant.Compiler ~/.julia/packages/Reactant/MAkBF/src/Compiler.jl:576
 [20] context!(f::Reactant.Compiler.var"#30#32"{typeof(f), Tuple{ConcreteRArray{Float64, 1}}}, ctx::Reactant.MLIR.IR.Context)
    @ Reactant.MLIR.IR ~/.julia/packages/Reactant/MAkBF/src/mlir/IR/Context.jl:71
 [21] compile_xla(f::Function, args::Tuple{ConcreteRArray{Float64, 1}}; client::Nothing)
    @ Reactant.Compiler ~/.julia/packages/Reactant/MAkBF/src/Compiler.jl:573
 [22] compile_xla
    @ ~/.julia/packages/Reactant/MAkBF/src/Compiler.jl:567 [inlined]
 [23] compile(f::Function, args::Tuple{ConcreteRArray{Float64, 1}}; client::Nothing)
    @ Reactant.Compiler ~/.julia/packages/Reactant/MAkBF/src/Compiler.jl:600
 [24] compile(f::Function, args::Tuple{ConcreteRArray{Float64, 1}})
    @ Reactant.Compiler ~/.julia/packages/Reactant/MAkBF/src/Compiler.jl:599
 [25] macro expansion
    @ ~/.julia/packages/Reactant/MAkBF/src/Compiler.jl:360 [inlined]
 [26] top-level scope
    @ ~/Work/GitHub/Julia/DifferentiationInterface.jl/DifferentiationInterface/test/Down/Reactant/test.jl:37
Some type information was truncated. Use `show(err)` to see complete types.

wsmoses (Member) commented Oct 1, 2024 via email

avik-pal added the good first issue (Good for newcomers) label on Oct 4, 2024

avik-pal (Collaborator) commented Oct 4, 2024

An easier fix, if you don't want to deal with HLO directly, is to simply call cat from vcat.
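The failure mode and the suggested fix can be illustrated without Reactant. The sketch below is hypothetical and self-contained: FakeScalar and FakeTraced are invented stand-ins, not Reactant types. It assumes, as the stacktrace suggests, that scalar-indexing a traced vector yields a non-Number value. Base's generic vcat allocates a plain Vector{Float64} and copies element by element, so each element must be converted to Float64, which produces the same convert MethodError. Forwarding vcat to a whole-array cat skips that element-wise path. Whether Reactant's own cat lowers to a single HLO concatenate is an assumption here, not something this sketch checks.

```julia
# Hypothetical stand-ins, NOT Reactant types. FakeScalar plays the role of the
# TracedRArray{Float64, 0} that scalar indexing produced in the stacktrace:
# a value that is not a Number, so it cannot be converted to Float64.
struct FakeScalar
    name::String
end

# FakeTraced plays the role of TracedRArray{Float64, 1}: it advertises
# eltype Float64, but indexing it returns a FakeScalar.
struct FakeTraced <: AbstractVector{Float64}
    n::Int
end
Base.size(x::FakeTraced) = (x.n,)
Base.getindex(x::FakeTraced, i::Int) = FakeScalar("x[$i]")

# 1. Generic Base path: vcat allocates a Vector{Float64} and copies element
#    by element, so setindex! has to convert each FakeScalar to Float64.
before = try
    vcat(FakeTraced(1), FakeTraced(1))
catch err
    err  # expected: MethodError from convert(Float64, ::FakeScalar)
end

# 2. Sketch of the suggested fix: give the traced type a whole-array `cat`
#    (here it just builds a longer FakeTraced; in Reactant this role is
#    assumed to be played by an HLO concatenate) and forward vcat to it.
function Base.cat(xs::FakeTraced...; dims)
    dims == 1 || throw(ArgumentError("this sketch only handles dims=1"))
    return FakeTraced(sum(x -> x.n, xs))
end
Base.vcat(xs::FakeTraced...) = cat(xs...; dims=1)

# Now vcat never indexes individual elements, so no convert is attempted.
after = vcat(FakeTraced(1), FakeTraced(1))
```

In Reactant itself the analogous overload would be defined for Reactant.TracedRArray; the exact signature used by the fix in #163 is not reproduced here.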

3 participants