From 0e764deba8dd63fdbdb76c946df45a7c695eb6c6 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 3 Jan 2025 15:55:14 -0500 Subject: [PATCH] Fix offsetarrays support (#464) * Fix offsetarrays support * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * add test file --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- Project.toml | 8 +++++--- ext/ReactantOffsetArraysExt.jl | 13 +++++++++++++ src/Tracing.jl | 2 +- test/Project.toml | 1 + test/integration/offsetarrays.jl | 19 +++++++++++++++++++ test/runtests.jl | 1 + 6 files changed, 40 insertions(+), 4 deletions(-) create mode 100644 ext/ReactantOffsetArraysExt.jl create mode 100644 test/integration/offsetarrays.jl diff --git a/Project.toml b/Project.toml index d5e57ef82..cbcac208d 100644 --- a/Project.toml +++ b/Project.toml @@ -24,20 +24,19 @@ AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" Random123 = "74087812-796a-5b5d-8853-05524746bad3" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df" -[sources.ReactantCore] -path = "lib/ReactantCore" - [extensions] ReactantAbstractFFTsExt = "AbstractFFTs" ReactantArrayInterfaceExt = "ArrayInterface" ReactantCUDAExt = "CUDA" ReactantNNlibExt = "NNlib" +ReactantOffsetArraysExt = "OffsetArrays" ReactantPythonCallExt = "PythonCall" ReactantRandom123Ext = "Random123" ReactantSpecialFunctionsExt = "SpecialFunctions" @@ -75,3 +74,6 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" + +[sources.ReactantCore] +path = "lib/ReactantCore" diff --git a/ext/ReactantOffsetArraysExt.jl b/ext/ReactantOffsetArraysExt.jl new file mode 100644 index 000000000..3ba82e9ac --- /dev/null +++ b/ext/ReactantOffsetArraysExt.jl @@ -0,0 +1,13 @@ +module ReactantOffsetArraysExt + +using OffsetArrays: OffsetArray +using Reactant: Reactant, MLIR, Ops, TracedRArray + +function Reactant.traced_type( + ::Type{<:OffsetArray{<:Any,N,T}}, seen::ST, ::Val{mode}, track_numbers +) where {T,N,ST,mode} + T2 = Reactant.traced_type(T, seen, Val(mode), track_numbers) + return OffsetArray{eltype(T2),N,T2} +end + +end diff --git a/src/Tracing.jl b/src/Tracing.jl index 2834c0f2c..797c4918d 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -705,7 +705,7 @@ function to_rarray_internal(@nospecialize(::TracedRArray), ::Tuple) end @inline to_rarray_internal(@nospecialize(x::ConcreteRArray), ::Tuple) = x @inline function to_rarray_internal( - @nospecialize(x::AbstractArray{<:ReactantPrimitive}), ::Tuple + @nospecialize(x::Array{<:ReactantPrimitive}), ::Tuple ) return ConcreteRArray(x) end diff --git a/test/Project.toml b/test/Project.toml index 90b315b86..917ba8e22 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -15,6 +15,7 @@ Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" diff --git a/test/integration/offsetarrays.jl b/test/integration/offsetarrays.jl new file mode 100644 index 000000000..a8bb4f899 --- /dev/null +++ b/test/integration/offsetarrays.jl @@ -0,0 +1,19 @@ +using Reactant +using Test +using OffsetArrays + +function scalar_index(x) + @allowscalar getindex(x, -1, 0) +end +@testset "OffsetArrays" begin + A = Float64.(reshape(1:15, 3, 5)) + OA = OffsetArray(A, -1:1, 0:4) + rOA = Reactant.to_rarray(OA) + + oval = scalar_index(OA) + cval = scalar_index(rOA) + @test cval ≈ oval + + tval = @jit scalar_index(rOA) + @test tval ≈ oval +end diff --git a/test/runtests.jl b/test/runtests.jl index f0e9ea1f4..0f019a12b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -62,6 +62,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) # Temporarily disabled as minutia are debugged # @safetestset "CUDA" include("integration/cuda.jl") @safetestset "Linear Algebra" include("integration/linear_algebra.jl") + @safetestset "OffsetArrays" include("integration/offsetarrays.jl") @safetestset "AbstractFFTs" include("integration/fft.jl") @safetestset "SpecialFunctions" include("integration/special_functions.jl") @safetestset "Random" include("integration/random.jl")