Skip to content

Commit

Permalink
Fix offsetarrays support (#464)
Browse files Browse the repository at this point in the history
* Fix offsetarrays support

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* add test file

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
wsmoses and github-actions[bot] authored Jan 3, 2025
1 parent ce20b3c commit 0e764de
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 4 deletions.
8 changes: 5 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,19 @@ AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
Random123 = "74087812-796a-5b5d-8853-05524746bad3"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df"

[sources.ReactantCore]
path = "lib/ReactantCore"

[extensions]
ReactantAbstractFFTsExt = "AbstractFFTs"
ReactantArrayInterfaceExt = "ArrayInterface"
ReactantCUDAExt = "CUDA"
ReactantNNlibExt = "NNlib"
ReactantOffsetArraysExt = "OffsetArrays"
ReactantPythonCallExt = "PythonCall"
ReactantRandom123Ext = "Random123"
ReactantSpecialFunctionsExt = "SpecialFunctions"
Expand Down Expand Up @@ -75,3 +74,6 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"

[sources.ReactantCore]
path = "lib/ReactantCore"
13 changes: 13 additions & 0 deletions ext/ReactantOffsetArraysExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
module ReactantOffsetArraysExt

using OffsetArrays: OffsetArray
using Reactant: Reactant, MLIR, Ops, TracedRArray

function Reactant.traced_type(
::Type{<:OffsetArray{<:Any,N,T}}, seen::ST, ::Val{mode}, track_numbers
) where {T,N,ST,mode}
T2 = Reactant.traced_type(T, seen, Val(mode), track_numbers)
return OffsetArray{eltype(T2),N,T2}
end

end
2 changes: 1 addition & 1 deletion src/Tracing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -705,7 +705,7 @@ function to_rarray_internal(@nospecialize(::TracedRArray), ::Tuple)
end
@inline to_rarray_internal(@nospecialize(x::ConcreteRArray), ::Tuple) = x
@inline function to_rarray_internal(
@nospecialize(x::AbstractArray{<:ReactantPrimitive}), ::Tuple
@nospecialize(x::Array{<:ReactantPrimitive}), ::Tuple
)
return ConcreteRArray(x)
end
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
Expand Down
19 changes: 19 additions & 0 deletions test/integration/offsetarrays.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using Reactant
using Test
using OffsetArrays

function scalar_index(x)
@allowscalar getindex(x, -1, 0)
end
@testset "OffsetArrays" begin
A = Float64.(reshape(1:15, 3, 5))
OA = OffsetArray(A, -1:1, 0:4)
rOA = Reactant.to_rarray(OA)

oval = scalar_index(OA)
cval = scalar_index(rOA)
@test cval oval

tval = @jit scalar_index(rOA)
@test tval oval
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))
# Temporarily disabled as minutia are debugged
# @safetestset "CUDA" include("integration/cuda.jl")
@safetestset "Linear Algebra" include("integration/linear_algebra.jl")
@safetestset "OffsetArrays" include("integration/offsetarrays.jl")
@safetestset "AbstractFFTs" include("integration/fft.jl")
@safetestset "SpecialFunctions" include("integration/special_functions.jl")
@safetestset "Random" include("integration/random.jl")
Expand Down

0 comments on commit 0e764de

Please sign in to comment.