From a17315c46fcdaeca2757a8855e6606f6404cb929 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 1 Nov 2024 15:21:49 -0400 Subject: [PATCH] feat: specialize dispatches for faster concrete array generation (#213) * feat: specialize dispatches for faster concrete array generation * chore: apply formatting suggestion Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/Tracing.jl | 24 +++++++++++++++++++++++- test/tracing.jl | 14 ++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/src/Tracing.jl b/src/Tracing.jl index b6037d7c9..30a617cc3 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -563,7 +563,29 @@ end @inline function to_rarray(@nospecialize(x); track_numbers::Union{Bool,Tuple}=()) track_numbers isa Bool && (track_numbers = track_numbers ? (Number,) : ()) + return to_rarray_internal(x, track_numbers) +end + +@inline function to_rarray_internal(@nospecialize(x), track_numbers::Tuple) return make_tracer(OrderedIdDict(), x, (), Reactant.ArrayToConcrete; track_numbers) end -to_rarray(x::ReactantPrimitive) = ConcreteRArray(x) +function to_rarray_internal(@nospecialize(::TracedRArray), ::Tuple) + return error("Cannot convert TracedRArray to ConcreteRArray") +end +@inline to_rarray_internal(@nospecialize(x::ConcreteRArray), ::Tuple) = x +@inline function to_rarray_internal( + @nospecialize(x::AbstractArray{<:ReactantPrimitive}), ::Tuple +) + return ConcreteRArray(x) +end + +@inline to_rarray_internal(@nospecialize(x::ConcreteRNumber), ::Tuple) = x +@inline function to_rarray_internal( + @nospecialize(x::ReactantPrimitive), track_numbers::Tuple +) + for T in track_numbers + typeof(x) <: T && return ConcreteRNumber(x) + end + return x +end diff --git a/test/tracing.jl b/test/tracing.jl index d75a43598..88b1e7662 100644 --- a/test/tracing.jl +++ b/test/tracing.jl @@ -100,4 +100,18 @@ using Test end end end + + @testset "specialized dispatches" begin + @test @inferred Union{Float64,ConcreteRArray{Float64}} Reactant.to_rarray( + 1.0; track_numbers=(Number,) + ) isa ConcreteRNumber + @test @inferred Reactant.to_rarray(1.0) isa Float64 + @test @inferred Reactant.to_rarray(rand(3)) isa ConcreteRArray + + x_ra = Reactant.to_rarray(rand(3)) + @test @inferred Reactant.to_rarray(x_ra) isa ConcreteRArray + + x_ra = Reactant.to_rarray(1.0; track_numbers=(Number,)) + @test @inferred Reactant.to_rarray(x_ra) isa ConcreteRNumber + end end