diff --git a/src/Reactant.jl b/src/Reactant.jl index 1dbb846cd..4e4d1616a 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -294,7 +294,7 @@ using Enzyme if aT === nothing throw("Unhandled type $T") end - if datatype_fieldcount(aT) === nothing + if Base.datatype_fieldcount(aT) === nothing throw("Unhandled type $T") end end @@ -342,7 +342,7 @@ using Enzyme end if Val(T) ∈ seen - return seen[T] + return T end seen = (Val(T), seen...) @@ -720,10 +720,40 @@ function generate_jlfunc(concrete_result, client, mod, Nargs, linear_args, linea end) return end + if T <: Array + elems = Symbol[] + for (i, v) in enumerate(tocopy) + sym = Symbol(string(resname)*"_"*string(i)) + create_result(v, sym, (path...,i)) + push!(elems, sym) + end + push!(concrete_result_maker, quote + $resname = $(eltype(T))[$(elems...)] + end) + return + end if T <: Int || T <: AbstractFloat || T <: AbstractString || T <: Nothing push!(concrete_result_maker, :($resname = $tocopy)) return end + if T <: Symbol + push!(concrete_result_maker, :($resname = $(QuoteNode(tocopy)))) + return + end + if isstructtype(T) + elems = Symbol[] + nf = fieldcount(T) + for i in 1:nf + sym = Symbol(resname, :_, i) + create_result(getfield(tocopy, i), sym, (path..., i)) + push!(elems, sym) + end + push!(concrete_result_maker, quote + flds = Any[$(elems...)] + $resname = ccall(:jl_new_structv, Any, (Any, Ptr{Cvoid}, UInt32), $T, flds, $nf) + end) + return + end error("canot copy $T") end diff --git a/test/runtests.jl b/test/runtests.jl index 0636ec776..64ab7e583 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -44,3 +44,4 @@ include("layout.jl") include("basic.jl") include("bcast.jl") include("nn.jl") +include("struct.jl") diff --git a/test/struct.jl b/test/struct.jl new file mode 100644 index 000000000..24d480adc --- /dev/null +++ b/test/struct.jl @@ -0,0 +1,92 @@ +using Reactant +using Test + +# from bsc-quantic/Tenet.jl +struct MockTensor{T,N,A<:AbstractArray{T,N}} + data::A + inds::Vector{Symbol} +end + +MockTensor(data::A, inds) where {T,N,A<:AbstractArray{T,N}} = MockTensor{T,N,A}(data, inds) +Base.parent(t::MockTensor) = t.data + +Base.cos(x::MockTensor) = MockTensor(cos(parent(x)), x.inds) + +mutable struct MutableMockTensor{T,N,A<:AbstractArray{T,N}} + data::A + inds::Vector{Symbol} +end + +MutableMockTensor(data::A, inds) where {T,N,A<:AbstractArray{T,N}} = MutableMockTensor{T,N,A}(data, inds) +Base.parent(t::MutableMockTensor) = t.data + +Base.cos(x::MutableMockTensor) = MutableMockTensor(cos(parent(x)), x.inds) + +# modified from JuliaCollections/DataStructures.jl +# NOTE original uses abstract type instead of union, which is not supported +mutable struct MockLinkedList{T} + head::T + tail::Union{MockLinkedList{T},Nothing} +end + +function list(x::T...) where {T} + l = nothing + for i in Iterators.reverse(eachindex(x)) + l = MockLinkedList{T}(x[i], l) + end + return l +end + +Base.sum(x::MockLinkedList{T}) where {T} = sum(x.head) + (!isnothing(x.tail) ? sum(x.tail) : 0) + +@testset "Struct" begin + @testset "MockTensor" begin + @testset "immutable" begin + x = MockTensor(rand(4, 4), [:i, :j]) + x2 = MockTensor(Reactant.ConcreteRArray(parent(x)), x.inds) + + f = Reactant.compile(cos, (x2,)) + y = f(x2) + + @test y isa MockTensor{Float64,2,Reactant.ConcreteRArray{Float64,(4, 4),2}} + @test isapprox(parent(y), cos.(parent(x))) + @test x.inds == [:i, :j] + end + + @testset "mutable" begin + x = MutableMockTensor(rand(4, 4), [:i, :j]) + x2 = MutableMockTensor(Reactant.ConcreteRArray(parent(x)), x.inds) + + f = Reactant.compile(cos, (x2,)) + y = f(x2) + + @test y isa MutableMockTensor{Float64,2,Reactant.ConcreteRArray{Float64,(4, 4),2}} + @test isapprox(parent(y), cos.(parent(x))) + @test x.inds == [:i, :j] + end + end + + @testset "MockLinkedList" begin + x = [rand(2, 2) for _ in 1:2] + x2 = list(x...) + x3 = Reactant.make_tracer(IdDict(), x2, (), Reactant.ArrayToConcrete, nothing) + x4 = list(Reactant.ConcreteRArray.(x)...) + + + # TODO this should be able to run without problems, but crashes + @test_broken begin + f = Reactant.compile(identity, (x3,)) + isapprox(f(x3), x3) + end + + f3 = Reactant.compile(sum, (x3,)) + f4 = Reactant.compile(sum, (x4,)) + + y = sum(x2) + y3 = f3(x3) + y4 = f4(x4) + + @test isapprox(y, y3) + @test isapprox(y, y4) + end +end