Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support to complex expr in compile_call_expr #351

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ExpressionExplorer = "21656369-7473-754a-2065-74616d696c43"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down Expand Up @@ -52,6 +53,7 @@ CEnum = "0.5"
CUDA = "5.5"
Downloads = "1.6"
Enzyme = "0.13.22"
ExpressionExplorer = "1.1.0"
EnzymeCore = "0.8.8"
GPUArraysCore = "0.1.6, 0.2"
LinearAlgebra = "1.10"
Expand Down
93 changes: 74 additions & 19 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ import ..Reactant:
ancestor,
TracedType

using ExpressionExplorer

@inline traced_getfield(@nospecialize(obj), field) = Base.getfield(obj, field)
@inline function traced_getfield(@nospecialize(obj::AbstractArray{T}), field) where {T}
(isbitstype(T) || ancestor(obj) isa RArray) && return Base.getfield(obj, field)
Expand Down Expand Up @@ -493,6 +495,56 @@ macro jit(args...)
#! format: on
end

#create expression for more complex expression than a call
function wrapped_expression(expr::Expr)
css = ExpressionExplorer.compute_symbols_state(expr)
tracked_definitions = []
tracked_names = []
alter_expr = (e::Expr) -> begin
for (i, arg) in enumerate(e.args)
arg isa Expr && alter_expr(arg)
is_tracking_call(arg) || continue
name = gensym(:tracked)
push!(tracked_definitions, arg)
push!(tracked_names, name)
e.args[i] = name
end
end
alter_expr(expr)

free_args = collect(css.references)
function_args = tuple([free_args; tracked_definitions]...)
args = tuple([free_args; tracked_names]...)
fname = gensym(:F)

return (
Expr(:tuple, function_args...),
quote
($fname)($(args...)) = $expr
end,
quote
$(fname)
end,
)
end

function is_tracking_call(input)
Meta.isexpr(input, :call) || return false
function_name = (ExpressionExplorer.explore_funcdef!(input, ExpressionExplorer.ScopeState()))[1].parts[end]
return function_name in [:to_rarray, :ConcreteRNumber, :ConcreteRArray]
end

#check if an expression need to be wrap in a closure
function need_wrap(expr::Expr)::Bool
for arg in expr.args
arg isa Expr || continue
Meta.isexpr(arg, :.) && continue
is_tracking_call(arg) && continue
return true
end
return false
end

function compile_call_expr(mod, compiler, options, args...)
while length(args) > 1
option, args = args[1], args[2:end]
Expand All @@ -505,36 +557,39 @@ function compile_call_expr(mod, compiler, options, args...)
end
end
call = only(args)
f_symbol = gensym(:f)
args_symbol = gensym(:args)
compiled_symbol = gensym(:compiled)

if Meta.isexpr(call, :call)
bcast, fname, fname_full = correct_maybe_bcast_call(call.args[1])
fname = if bcast
quote
if isdefined(mod, $(Meta.quot(fname_full)))
$(fname_full)
else
Base.Broadcast.BroadcastFunction($(fname))
closure = ()
if call isa Expr && need_wrap(call)
(args_rhs, closure, fname) = wrapped_expression(call)
else
if Meta.isexpr(call, :call)
bcast, fname, fname_full = correct_maybe_bcast_call(call.args[1])
fname = if bcast
quote
if isdefined(mod, $(Meta.quot(fname_full)))
$(fname_full)
else
Base.Broadcast.BroadcastFunction($(fname))
end
end
else
:($(fname))
end
args_rhs = Expr(:tuple, call.args[2:end]...)
elseif Meta.isexpr(call, :(.), 2) && Meta.isexpr(call.args[2], :tuple)
fname = :($(Base.Broadcast.BroadcastFunction)($(call.args[1])))
args_rhs = only(call.args[2:end])
else
:($(fname))
error("Invalid function call: $(call)")
end
args_rhs = Expr(:tuple, call.args[2:end]...)
elseif Meta.isexpr(call, :(.), 2) && Meta.isexpr(call.args[2], :tuple)
fname = :($(Base.Broadcast.BroadcastFunction)($(call.args[1])))
args_rhs = only(call.args[2:end])
else
error("Invalid function call: $(call)")
end

return quote
$(f_symbol) = $(fname)
$closure
$(args_symbol) = $(args_rhs)
$(compiled_symbol) = $(compiler)(
$(f_symbol), $(args_symbol); $(Expr.(:kw, keys(options), values(options))...)
$(fname), $(args_symbol); $(Expr.(:kw, keys(options), values(options))...)
)
end,
(; compiled=compiled_symbol, args=args_symbol)
Expand Down
10 changes: 10 additions & 0 deletions test/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,16 @@ f_var(args...) = sum(args)
@test @jit(f_var(x, y, z)) ≈ [6.6, 6.6, 6.6]
end

@testset "Complex expression" begin
x = Reactant.to_rarray(ones(3))
y = Reactant.ConcreteRNumber(3)
f(x) = x .+ 1
kw(x; a) = x * a
@test @jit(kw(x; a=y)) ≈ x * y
@test @jit(x + x - x + x * float(Base.pi) * 0) ≈ x
@test @jit(f(f(f(f(x)))) .+ Reactant.to_rarray(ones(3))) ≈ @allowscalar x .+ 5
end

function sumcos(x)
return sum(cos.(x))
end
Expand Down
Loading