Skip to content

Commit

Permalink
feat: partial NNlib.gather support + better indexing support (#252)
Browse files Browse the repository at this point in the history
* feat: unbreak NNlib.gather

* feat: use dynamic slicing

* chore: apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* feat: add an overload of Base.Tuple

* fix: ambiguity error

* feat: special case `gather!` for the most common cases

* feat: optimize the special case of indexing with unitranges

* test: dynamic slice test

* test: port NNlib gather tests over

* chore: apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* fix: mark length as Int64

* fix: use the C API for dimension numbers
  • Loading branch information
avik-pal authored Nov 10, 2024
1 parent df04f34 commit 9d666f8
Show file tree
Hide file tree
Showing 5 changed files with 317 additions and 38 deletions.
60 changes: 60 additions & 0 deletions ext/ReactantNNlibExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -296,4 +296,64 @@ function NNlib.make_causal_mask(x::AnyTracedRArray; dims::Int=2)
)
end

# XXX: reevaluate this manual optimization once
# https://github.com/EnzymeAD/Enzyme-JAX/issues/164 is handled
#
# Fast path: gathering whole columns with a unit-range index is just a
# contiguous 2D slice, so reuse the `getindex` lowering instead of emitting a
# gather op. (The original signature wrapped the index type in a single-member
# `Union{...}`, which is identical to the bare type — dropped.)
function NNlib.gather!(
    dst::TracedRArray{T1,2},
    src::AnyTracedRArray{T2,2},
    idxs::AbstractUnitRange{<:Number},
) where {T1,T2}
    # `src[:, idxs]` lowers to a slice; reuse its MLIR value as `dst`'s data.
    dst.mlir_data = src[:, idxs].mlir_data
    return dst
end

# Gather columns of a 2D `src` selected by an arbitrary numeric index vector,
# lowered to a single `stablehlo.dynamic_gather` op (indices may be traced).
function NNlib.gather!(
    dst::TracedRArray{T1,2}, src::AnyTracedRArray{T2,2}, idxs::AbstractVector{<:Number}
) where {T1,T2}
    dims = NNlib.scatter_dims(src, dst, idxs)
    @assert dims == 1 # scatter_dims lets us do some size checks so we call that function
    # StableHLO gather indices are 0-based; promote to a traced Int vector.
    idxs = (Reactant.promote_to(TracedRArray{Int,1}, idxs) .- 1).mlir_data
    # Each gathered slice is one full column: size (size(src, 1), 1).
    slice_sizes = Reactant.promote_to(TracedRArray{Int,1}, [size(src, 1), 1]).mlir_data

    # Dimension numbers are built through the C API as (count, values) pairs.
    # NOTE(review): assumed field order is offset_dims, collapsed_slice_dims,
    # operand_batching_dims, start_indices_batching_dims, start_index_map,
    # then index_vector_dim — confirm against stablehloGatherDimensionNumbersGet.
    #! format: off
    dimension_numbers = MLIR.API.stablehloGatherDimensionNumbersGet(
        MLIR.IR.context(),
        Int64(1), Int64[0],
        Int64(1), Int64[1],
        Int64(0), Int64[],
        Int64(0), Int64[],
        Int64(1), Int64[1],
        Int64(1)
    )
    #! format: on

    res = MLIR.IR.result(
        Reactant.MLIR.Dialects.stablehlo.dynamic_gather(
            src.mlir_data, idxs, slice_sizes; dimension_numbers
        ),
        1,
    )
    # Overwrite dst's SSA value in place (traced arrays mutate via mlir_data).
    dst.mlir_data = res
    return dst
end

# XXX: For performance use `stablehlo.dynamic_gather` or at least use a traced
# loop instead of unrolling the loop (the case for AbstractArray can just use
# `stablehlo.gather`). See above for the special case implementation that is
# optimized.
#
# Generic fallback: unrolls one slice per index at trace time, then
# concatenates and reshapes the pieces into `dst`'s shape.
function NNlib.gather!(dst::TracedRArray, src::AnyTracedRArray, idxs::AbstractArray)
    @warn "Using fallback implementation of `gather!` using `stablehlo.dynamic_slice`. \
        This case is not optimized and will be slow." maxlog = 1
    dims = NNlib.scatter_dims(src, dst, idxs)
    # The leading `dims` axes of `src` are carried over whole; trailing axes are
    # selected by each (possibly multi-dimensional) index entry.
    colons = ntuple(Returns(Colon()), dims)
    start_sizes = ntuple(i -> size(src, i), dims)
    results = map(CartesianIndices(idxs)) do k
        res = src[colons..., Tuple(idxs[k])...]
        # A fully-scalar selection yields a TracedRNumber; lift it to a 1-vector
        # so every piece reshapes/concatenates uniformly.
        res isa TracedRNumber && (res = Reactant.broadcast_to_size(res, (1,)))
        return reshape(res, start_sizes..., :)
    end
    res = reshape(cat(results...; dims=(dims + 1)), size(dst))
    dst.mlir_data = res.mlir_data
    return dst
end

end # module ReactantNNlibExt
74 changes: 50 additions & 24 deletions src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ function get_ancestor_indices(x::WrappedTracedRArray, indices...)
return get_ancestor_indices(parent(x), Base.reindex(parentindices(x), indices)...)
end

function Base.getindex(a::TracedRArray{T,N}, index::Vararg{Int,N}) where {T,N}
function Base.getindex(
a::TracedRArray{T,N}, index::Vararg{Union{Int,TracedRNumber{Int}},N}
) where {T,N}
@warn(
"""Performing scalar indexing on task $(current_task()).
Invocation resulted in scalar indexing of a TracedRArray.
Expand All @@ -65,49 +67,59 @@ Such implementations *do not* execute on device, but very slowly on the CPU,
and require expensive copies and synchronization each time and therefore should be avoided."""
)

start_indices = [promote_to(TracedRNumber{Int}, i - 1).mlir_data for i in index]
slice_sizes = [Int64(1) for _ in index]

res1 = MLIR.IR.result(
MLIR.Dialects.stablehlo.slice(
a.mlir_data;
start_indices=MLIR.IR.DenseArrayAttribute([Int64(i - 1) for i in index]),
limit_indices=MLIR.IR.DenseArrayAttribute([Int64(i) for i in index]),
strides=MLIR.IR.DenseArrayAttribute([Int64(1) for i in index]),
),
1,
MLIR.Dialects.stablehlo.dynamic_slice(a.mlir_data, start_indices; slice_sizes), 1
)
res2 = MLIR.IR.result(
MLIR.Dialects.stablehlo.reshape(
res1; result_0=MLIR.IR.TensorType(Int64[], eltype(MLIR.IR.type(res1)))
),
1,
)

return TracedRNumber{T}((), res2)
end

# A 0-dim traced array indexed with no indices is just its scalar value; wrap
# the same underlying MLIR value in a TracedRNumber (no new ops emitted).
Base.getindex(a::TracedRArray{T,0}) where {T} = TracedRNumber{T}((), a.mlir_data)

# XXX: We want to support https://github.com/EnzymeAD/Reactant.jl/issues/242 eventually
#
# General N-dimensional indexing lowered to a single `stablehlo.dynamic_slice`:
# each index becomes a (possibly traced) start offset plus a static slice
# length. Dimensions indexed by a plain integer are dropped, mirroring Base.
function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N}
    # Normalize indices: `:` -> full 1-based range, CartesianIndex -> tuple.
    indices = map(enumerate(indices)) do (idx, i)
        i isa Colon && return 1:size(a, idx)
        i isa CartesianIndex && return Tuple(i)
        return i
    end

    # dynamic_slice can only take contiguous windows, so reject non-contiguous
    # indices — but only when contiguity is statically known (a plain Bool).
    # BUG FIX: the original tested `typeof(a) <: Bool`, which is always false
    # for a TracedRArray, so the error below was unreachable; the intended
    # guard (per the XXX note) is on `contiguous` itself.
    foreach(indices) do idxs
        idxs isa Number && return nothing
        contiguous = all(isone, diff(idxs))
        # XXX: We want to throw error even for dynamic indexing
        if typeof(contiguous) <: Bool
            contiguous || error("non-contiguous indexing is not supported")
        end
    end

    # 0-based traced start offsets, one per dimension.
    start_indices = map(indices) do i
        return promote_to(TracedRNumber{Int}, first(i) - 1).mlir_data
    end
    slice_sizes = [Int64(length(i)) for i in indices]
    res = MLIR.IR.result(
        MLIR.Dialects.stablehlo.dynamic_slice(a.mlir_data, start_indices; slice_sizes), 1
    )

    x = TracedRArray{T,N}((), res, Tuple(length.(indices)))
    # Drop dimensions indexed with a plain integer (Base indexing semantics).
    ddims = findall(Base.Fix2(isa, Integer), indices)
    isempty(ddims) || return dropdims(x; dims=Tuple(ddims))
    return x
end

# Prevent ambiguity with the generic wrapped-array indexing method: scalar
# (Int or traced-Int) indexing forwards to the ancestor array with the
# indices translated through the wrapper's parent indices.
function Base.getindex(a::WrappedTracedRArray, index::Union{Int,TracedRNumber{Int}}...)
    return getindex(ancestor(a), get_ancestor_indices(a, index...)...)
end

Expand All @@ -116,7 +128,9 @@ function Base.getindex(a::WrappedTracedRArray, indices...)
end

function Base.setindex!(
a::TracedRArray{T,N}, v, indices::Vararg{Union{Base.AbstractUnitRange,Colon,Int},N}
a::TracedRArray{T,N},
v,
indices::Vararg{Union{Base.AbstractUnitRange,Colon,Int,TracedRNumber{Int}},N},
) where {T,N}
indices = map(enumerate(indices)) do (idx, i)
i isa Int ? (i:i) : (i isa Colon ? (1:size(a, idx)) : i)
Expand All @@ -138,13 +152,17 @@ function Base.setindex!(
end

# setindex! on any wrapped traced array: translate the indices through the
# wrapper and write into the underlying ancestor TracedRArray.
function Base.setindex!(
    a::AnyTracedRArray{T,N},
    v,
    indices::Vararg{Union{Base.AbstractUnitRange,Colon,Int,TracedRNumber{Int}},N},
) where {T,N}
    ancestor_indices = get_ancestor_indices(a, indices...)
    setindex!(ancestor(a), v, ancestor_indices...)
    return a
end

# Materialize a traced array as a tuple of traced scalars, one `getindex` per
# element (unrolled at trace time — intended for small arrays).
Base.Tuple(x::TracedRArray) = ntuple(i -> x[i], length(x))

# The shape is stored on the traced array itself; no MLIR query needed.
Base.size(x::TracedRArray) = x.shape

# "Copy" shares the same underlying MLIR SSA value and shape.
# NOTE(review): presumably safe because traced values are written only via
# `mlir_data` reassignment, not in-place mutation — confirm.
Base.copy(A::TracedRArray{T,N}) where {T,N} = TracedRArray{T,N}((), A.mlir_data, size(A))
Expand Down Expand Up @@ -699,7 +717,7 @@ end

# Materialize a plain scalar as a constant traced array of shape `rsize`
# (emits a single `stablehlo.constant` filled with `arg`).
function broadcast_to_size(arg::T, rsize) where {T<:Number}
    attr = MLIR.IR.DenseElementsAttribute(Base.fill(arg, Tuple(rsize)))
    return TracedRArray{T,length(rsize)}(
        (), MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1), rsize
    )
end
Expand All @@ -711,6 +729,11 @@ function broadcast_to_size(arg::TracedRNumber, rsize)
)
end

# A zero-dimensional traced array broadcasts exactly like the scalar it wraps:
# materialize it, view its MLIR value as a TracedRNumber, and delegate.
function broadcast_to_size(arg::AnyTracedRArray{T,0}, rsize) where {T}
    scalar = TracedRNumber{T}((), materialize_traced_array(arg).mlir_data)
    return broadcast_to_size(scalar, rsize)
end

function broadcast_to_size(arg::AnyTracedRArray, rsize)
arg = materialize_traced_array(arg)
size(arg) == rsize && return arg
Expand Down Expand Up @@ -856,3 +879,6 @@ for (minT, maxT) in Iterators.product((Number, TracedRNumber), (Number, TracedRN
return x
end
end

# `all`/`any` over traced arrays delegate to `mapreduce` with the boolean
# AND / OR operators, so the reduction is traced rather than iterated eagerly.
function Base.all(f::Function, x::TracedRArray)
    return mapreduce(f, &, x)
end
function Base.any(f::Function, x::TracedRArray)
    return mapreduce(f, |, x)
end
32 changes: 18 additions & 14 deletions src/TracedRNumber.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,20 +193,24 @@ function Base.ifelse(
)
end

function Base.:&(x::TracedRNumber{Bool}, y::TracedRNumber{Bool})
return TracedRNumber{Bool}(
(), MLIR.IR.result(MLIR.Dialects.stablehlo.and(x.mlir_data, y.mlir_data), 1)
)
end
function Base.:|(x::TracedRNumber{Bool}, y::TracedRNumber{Bool})
return TracedRNumber{Bool}(
(), MLIR.IR.result(MLIR.Dialects.stablehlo.or(x.mlir_data, y.mlir_data), 1)
)
end
function Base.:!(x::TracedRNumber{Bool})
return TracedRNumber{Bool}(
(), MLIR.IR.result(MLIR.Dialects.stablehlo.not(x.mlir_data), 1)
)
# Bitwise/logical `&`, `|`, `!` on traced booleans and integers, lowered to the
# corresponding stablehlo ops. Methods are generated for element-type bounds
# `Bool` and `Integer`; mixed Bool/Integer operands dispatch to the `Integer`
# methods (since `Bool <: Integer`) and promote via `promote_type`.
# NOTE: the original iterated `zip((Bool, Integer), (Bool, Integer))`, which
# only ever yields identical pairs — a single loop over the tuple generates
# exactly the same method set with less indirection.
for T in (Bool, Integer)
    @eval begin
        function Base.:&(x::TracedRNumber{<:$(T)}, y::TracedRNumber{<:$(T)})
            return TracedRNumber{promote_type(eltype(x), eltype(y))}(
                (),
                MLIR.IR.result(MLIR.Dialects.stablehlo.and(x.mlir_data, y.mlir_data), 1),
            )
        end
        function Base.:|(x::TracedRNumber{<:$(T)}, y::TracedRNumber{<:$(T)})
            return TracedRNumber{promote_type(eltype(x), eltype(y))}(
                (),
                MLIR.IR.result(MLIR.Dialects.stablehlo.or(x.mlir_data, y.mlir_data), 1),
            )
        end
        function Base.:!(x::TracedRNumber{<:$(T)})
            return TracedRNumber{eltype(x)}(
                (), MLIR.IR.result(MLIR.Dialects.stablehlo.not(x.mlir_data), 1)
            )
        end
    end
end

function Base.literal_pow(
Expand Down
11 changes: 11 additions & 0 deletions test/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -557,3 +557,14 @@ end
@test minimum(y) 0.0
@test x_ra x
end

@testset "dynamic indexing" begin
    x = randn(5, 3)
    x_ra = Reactant.to_rarray(x)

    # The index vector is itself a traced array, exercising the dynamic
    # (traced-index) lowering of getindex.
    idx = [1, 2, 3]
    idx_ra = Reactant.to_rarray(idx)

    y = @jit(getindex(x_ra, idx_ra, :))
    # NOTE(review): the comparison operator was lost in transit; `≈` restored,
    # matching the approximate-equality style of the surrounding tests.
    @test y ≈ x[idx, :]
end
Loading

1 comment on commit 9d666f8

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reactant.jl Benchmarks

Benchmark suite Current: 9d666f8 Previous: 3ba7c3e Ratio
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme) 6791483981 ns 5787425685 ns 1.17
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant 5758935576 ns 5292258390 ns 1.09
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme) 5991010944 ns 6086056532 ns 0.98
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme) 7226896499 ns 7587601119 ns 0.95
ViT base (256 x 256 x 3 x 32)/forward/CPU/Lux 35602806226 ns 28087750784 ns 1.27
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme) 1554984111 ns 1563822331 ns 0.99
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant 1544583024 ns 1543677512 ns 1.00
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme) 1548949983 ns 1553822136 ns 1.00
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme) 3305789982 ns 3309603029 ns 1.00
ViT small (256 x 256 x 3 x 4)/forward/CPU/Lux 2506900881 ns 3236551447 ns 0.77
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme) 2135647085 ns 2198150190 ns 0.97
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant 2114392785 ns 2155687426 ns 0.98
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme) 2152785430 ns 2192886728 ns 0.98
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme) 3920703119 ns 3908194881 ns 1.00
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Lux 6189728956.5 ns 5993416352 ns 1.03
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme) 1457078887 ns 1406808783.5 ns 1.04
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant 1444716715 ns 1407299141 ns 1.03
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme) 1436557810 ns 1410969730 ns 1.02
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme) 3216204610 ns 3156311368 ns 1.02
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Lux 1056965356.5 ns 1099155376.5 ns 0.96
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme) 1742579811 ns 1727787162 ns 1.01
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant 1724456565 ns 1727804980 ns 1.00
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme) 1703597363 ns 1711663111 ns 1.00
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme) 3527083552 ns 3460051766 ns 1.02
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Lux 3168974882.5 ns 3010659432 ns 1.05
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme) 2194009586 ns 2148427239 ns 1.02
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant 2185020760 ns 2170426380 ns 1.01
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme) 2145844912 ns 2187259107 ns 0.98
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme) 3888344576 ns 3958804601 ns 0.98
ViT small (256 x 256 x 3 x 16)/forward/CPU/Lux 5917571640 ns 6647100753 ns 0.89
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme) 2952461976 ns 3146044029 ns 0.94
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant 3000541626 ns 3146912971 ns 0.95
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme) 3037769031 ns 3047329260 ns 1.00
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme) 5003447240 ns 4862728550 ns 1.03
ViT small (256 x 256 x 3 x 32)/forward/CPU/Lux 16939295133 ns 12794226734 ns 1.32
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme) 3182179123 ns 3132478421 ns 1.02
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant 3260047692 ns 3179953038 ns 1.03
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme) 3279955673 ns 3185074336 ns 1.03
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme) 5217808864 ns 5092564084 ns 1.02
ViT base (256 x 256 x 3 x 16)/forward/CPU/Lux 13308548315 ns 12253319305 ns 1.09
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme) 1848723326 ns 1855345054 ns 1.00
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant 1856209821 ns 1849809131 ns 1.00
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme) 1839420405 ns 1855337197 ns 0.99
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme) 3618461240 ns 3604644289 ns 1.00
ViT base (256 x 256 x 3 x 4)/forward/CPU/Lux 3120647639.5 ns 5868629461.5 ns 0.53

This comment was automatically generated by a workflow using github-action-benchmark.

Please sign in to comment.