From 38e10fe06b37cb71865994844ba958e82fc40a56 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Tue, 10 Dec 2024 03:09:35 +0100 Subject: [PATCH 1/5] add support to complex expr in compile_call_expr --- Project.toml | 6 ++-- src/Compiler.jl | 75 ++++++++++++++++++++++++++++++++++++------------- test/basic.jl | 7 +++++ 3 files changed, 67 insertions(+), 21 deletions(-) diff --git a/Project.toml b/Project.toml index 648784816..beee67ad1 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,7 @@ CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" +ExpressionExplorer = "21656369-7473-754a-2065-74616d696c43" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -25,8 +26,8 @@ NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df" -[sources.ReactantCore] -path = "lib/ReactantCore" +[sources] +ReactantCore = {path = "lib/ReactantCore"} [extensions] ReactantAbstractFFTsExt = "AbstractFFTs" @@ -43,6 +44,7 @@ CEnum = "0.4, 0.5" Downloads = "1.6" Enzyme = "0.13.21" EnzymeCore = "0.8.6, 0.8.7, 0.8.8" +ExpressionExplorer = "1.1.0" GPUArraysCore = "0.1.6, 0.2" LinearAlgebra = "1.10" NNlib = "0.9.24" diff --git a/src/Compiler.jl b/src/Compiler.jl index 586f33b05..b9674c61a 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -14,6 +14,8 @@ import ..Reactant: append_path, TracedType +using ExpressionExplorer + @inline traced_getfield(@nospecialize(obj), field) = Base.getfield(obj, field) @inline traced_getfield( @nospecialize(obj::AbstractArray{<:Union{ConcreteRNumber,ConcreteRArray}}), field @@ -432,6 +434,38 @@ macro jit(args...) #! format: on end +is_a_module(s::Symbol)::Bool = begin + isdefined(@__MODULE__, s) && getproperty(@__MODULE__, s) isa Module +end + +#create expression for more complex expression than a call +function wrapped_expression(expr::Expr) + args = ExpressionExplorer.compute_symbols_state(expr).references + args = filter(!is_a_module, args) + args = tuple(collect(args)...) + fname = gensym(:F) + + return ( + Expr(:tuple, args...), + quote + ($fname)($(args...)) = $expr + end, + quote + $(fname) + end, + ) +end + +#check if an expression need to be wrap in a closure +function need_wrap(expr::Expr)::Bool + for arg in expr.args + arg isa Expr || continue + Meta.isexpr(arg, :.) && continue + return true + end + return false +end + function compile_call_expr(mod, compiler, options, args...) while length(args) > 1 option, args = args[1], args[2:end] @@ -444,36 +478,39 @@ function compile_call_expr(mod, compiler, options, args...) end end call = only(args) - f_symbol = gensym(:f) args_symbol = gensym(:args) compiled_symbol = gensym(:compiled) - - if Meta.isexpr(call, :call) - bcast, fname, fname_full = correct_maybe_bcast_call(call.args[1]) - fname = if bcast - quote - if isdefined(mod, $(Meta.quot(fname_full))) - $(fname_full) - else - Base.Broadcast.BroadcastFunction($(fname)) + closure = () + if call isa Expr && need_wrap(call) + (args_rhs, closure, fname) = wrapped_expression(call) + else + if Meta.isexpr(call, :call) + bcast, fname, fname_full = correct_maybe_bcast_call(call.args[1]) + fname = if bcast + quote + if isdefined(mod, $(Meta.quot(fname_full))) + $(fname_full) + else + Base.Broadcast.BroadcastFunction($(fname)) + end end + else + :($(fname)) end + args_rhs = Expr(:tuple, call.args[2:end]...) + elseif Meta.isexpr(call, :(.), 2) && Meta.isexpr(call.args[2], :tuple) + fname = :($(Base.Broadcast.BroadcastFunction)($(call.args[1]))) + args_rhs = only(call.args[2:end]) else - :($(fname)) + error("Invalid function call: $(call)") end - args_rhs = Expr(:tuple, call.args[2:end]...) - elseif Meta.isexpr(call, :(.), 2) && Meta.isexpr(call.args[2], :tuple) - fname = :($(Base.Broadcast.BroadcastFunction)($(call.args[1]))) - args_rhs = only(call.args[2:end]) - else - error("Invalid function call: $(call)") end return quote - $(f_symbol) = $(fname) + $closure $(args_symbol) = $(args_rhs) $(compiled_symbol) = $(compiler)( - $(f_symbol), $(args_symbol); $(Expr.(:kw, keys(options), values(options))...) + $(fname), $(args_symbol); $(Expr.(:kw, keys(options), values(options))...) ) end, (; compiled=compiled_symbol, args=args_symbol) diff --git a/test/basic.jl b/test/basic.jl index 6aac778f2..b1b964416 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -101,6 +101,13 @@ f_var(args...) = sum(args) @test @jit(f_var(x, y, z)) ≈ [6.6, 6.6, 6.6] end +@testset "Complex expression" begin + x = Reactant.to_rarray(ones(3)) + f(x) = x .+ 1 + @test @jit(x + x - x + x * float(Base.pi) * 0) ≈ x + @test @jit(f(f(f(f(x))))) ≈ @allowscalar x .+ 4 +end + function sumcos(x) return sum(cos.(x)) end From 47259575b520c37cef81c338fe3aa9d1e199e078 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Wed, 11 Dec 2024 15:22:39 +0100 Subject: [PATCH 2/5] restore support for '@jit foo(Reactant.to_rarray(rand(2)))' and add '@jit foo(foo(Reactant.to_rarray(rand(2))))' --- src/Compiler.jl | 30 ++++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index b9674c61a..388357583 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -440,13 +440,28 @@ end #create expression for more complex expression than a call function wrapped_expression(expr::Expr) - args = ExpressionExplorer.compute_symbols_state(expr).references - args = filter(!is_a_module, args) - args = tuple(collect(args)...) + css = ExpressionExplorer.compute_symbols_state(expr) + tracked_definitions = [] + tracked_names = [] + alter_expr = (e::Expr) -> begin + for (i, arg) in enumerate(e.args) + arg isa Expr && alter_expr(arg) + is_tracking_call(arg) || continue + name = gensym(:tracked) + push!(tracked_definitions, arg) + push!(tracked_names, name) + e.args[i] = name + end + end + alter_expr(expr) + + free_args = collect(css.references) + function_args = tuple([free_args; tracked_definitions]...) + args = tuple([free_args; tracked_names]...) fname = gensym(:F) return ( - Expr(:tuple, args...), + Expr(:tuple, function_args...), quote ($fname)($(args...)) = $expr end, @@ -456,11 +471,18 @@ function wrapped_expression(expr::Expr) ) end +function is_tracking_call(input) + Meta.isexpr(input, :call) || return false + function_name = (ExpressionExplorer.explore_funcdef!(input, ExpressionExplorer.ScopeState()))[1].parts[end] + function_name in [:to_rarray, :ConcreteRNumber] +end + #check if an expression need to be wrap in a closure function need_wrap(expr::Expr)::Bool for arg in expr.args arg isa Expr || continue Meta.isexpr(arg, :.) && continue + is_tracking_call(arg) && continue return true end return false From ee7a540c8e2721253fb0c7f16d8c5b43f0a5f84a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Wed, 11 Dec 2024 15:54:49 +0100 Subject: [PATCH 3/5] cleanup --- src/Compiler.jl | 6 +----- test/basic.jl | 2 +- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index 388357583..9d2726749 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -434,10 +434,6 @@ macro jit(args...) #! format: on end -is_a_module(s::Symbol)::Bool = begin - isdefined(@__MODULE__, s) && getproperty(@__MODULE__, s) isa Module -end - #create expression for more complex expression than a call function wrapped_expression(expr::Expr) css = ExpressionExplorer.compute_symbols_state(expr) @@ -474,7 +470,7 @@ end function is_tracking_call(input) Meta.isexpr(input, :call) || return false function_name = (ExpressionExplorer.explore_funcdef!(input, ExpressionExplorer.ScopeState()))[1].parts[end] - function_name in [:to_rarray, :ConcreteRNumber] + function_name in [:to_rarray, :ConcreteRNumber, :ConcreteRArray] end #check if an expression need to be wrap in a closure diff --git a/test/basic.jl b/test/basic.jl index 88c9d3ceb..1d80e7887 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -105,7 +105,7 @@ end x = Reactant.to_rarray(ones(3)) f(x) = x .+ 1 @test @jit(x + x - x + x * float(Base.pi) * 0) ≈ x - @test @jit(f(f(f(f(x))))) ≈ @allowscalar x .+ 4 + @test @jit(f(f(f(f(x)))) .+ Reactant.to_rarray(ones(3))) ≈ @allowscalar x .+ 5 end function sumcos(x) From ad16dd5574e5f511a48b08de1bc6e9f490e1c5af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Sat, 14 Dec 2024 21:34:06 +0100 Subject: [PATCH 4/5] add support to kw call --- test/basic.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/basic.jl b/test/basic.jl index eee44bcb3..8b050aefc 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -103,7 +103,10 @@ end @testset "Complex expression" begin x = Reactant.to_rarray(ones(3)) + y = Reactant.ConcreteRNumber(3) f(x) = x .+ 1 + kw(x; a) = x * a + @test @jit(kw(x; a = y)) ≈ x * y @test @jit(x + x - x + x * float(Base.pi) * 0) ≈ x @test @jit(f(f(f(f(x)))) .+ Reactant.to_rarray(ones(3))) ≈ @allowscalar x .+ 5 end From e90a00c0b11d189b88a47019c10571e3bfedb3c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Wed, 1 Jan 2025 02:46:44 +0100 Subject: [PATCH 5/5] format --- src/Compiler.jl | 2 +- test/basic.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index f8e5e30ec..5a25bdf68 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -531,7 +531,7 @@ end function is_tracking_call(input) Meta.isexpr(input, :call) || return false function_name = (ExpressionExplorer.explore_funcdef!(input, ExpressionExplorer.ScopeState()))[1].parts[end] - function_name in [:to_rarray, :ConcreteRNumber, :ConcreteRArray] + return function_name in [:to_rarray, :ConcreteRNumber, :ConcreteRArray] end #check if an expression need to be wrap in a closure diff --git a/test/basic.jl b/test/basic.jl index 082f1544f..f30847d16 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -106,7 +106,7 @@ end y = Reactant.ConcreteRNumber(3) f(x) = x .+ 1 kw(x; a) = x * a - @test @jit(kw(x; a = y)) ≈ x * y + @test @jit(kw(x; a=y)) ≈ x * y @test @jit(x + x - x + x * float(Base.pi) * 0) ≈ x @test @jit(f(f(f(f(x)))) .+ Reactant.to_rarray(ones(3))) ≈ @allowscalar x .+ 5 end