From 65e9976ac1fc964d78b85298b0e4ca2a1cf74bb8 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 14 Dec 2024 14:05:03 -0600 Subject: [PATCH] Interp2 (#365) * WIP: kernels * more files * fix * wip * wqtmp * wip * inc * continuing * wip * more work * inf rec * fix * overload working * continuing * continuing * push * fix `call_with_reactant_generator` for Julia 1.11 (#359) * conversion * continuing * Cleanup * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Delete test/cuda.jl * fixup * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fix apply * indep of change * minor fix in name * Update utils.jl * Interp take 2 * continuing adentures * delcode * fix * tmp * make * fix * cleanup * continuing * more working * further simplify * fx * more improvements * minus show * less prints * even fewer * confusion * tmp * force clean * force oc * clean * Rewrite * fixup * fix * fix * fix * fixup * fix * wip * safe prints * fix * fix * stackoverflow * cleanup * dyindex * rt * continue * clean * fix * fix * fix * fix * fixup * fix * fix * capture oc * compile perf * v1.11 fix * other way 'round * formatting --------- Co-authored-by: William Moses Co-authored-by: jumerckx <31353884+jumerckx@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: jumerckx --- Project.toml | 4 +- deps/ReactantExtra/API.cpp | 10 + deps/ReactantExtra/BUILD | 2 + ext/ReactantNNlibExt.jl | 27 +- ext/ReactantStatisticsExt.jl | 3 +- ext/ReactantYaoBlocksExt.jl | 7 +- lib/ReactantCore/Project.toml | 2 +- lib/ReactantCore/src/ReactantCore.jl | 6 +- src/Compiler.jl | 9 +- src/ConcreteRArray.jl | 10 +- src/ControlFlow.jl | 16 +- src/Interpreter.jl | 302 +++-------- src/Ops.jl | 191 ++++--- src/Reactant.jl | 93 ++-- src/TracedRArray.jl | 413 ++++----------- src/TracedRNumber.jl | 101 ++-- src/TracedUtils.jl | 530 +++++++++++++++++++ src/Tracing.jl | 8 +- src/linear_algebra.jl | 42 +- src/utils.jl | 758 +++++++++++++++++++-------- test/basic.jl | 2 +- test/complex.jl | 2 +- 22 files changed, 1578 insertions(+), 960 deletions(-) create mode 100644 src/TracedUtils.jl diff --git a/Project.toml b/Project.toml index ac37e645a..9af7dafef 100644 --- a/Project.toml +++ b/Project.toml @@ -41,14 +41,14 @@ Adapt = "4" ArrayInterface = "7.10" CEnum = "0.4, 0.5" Downloads = "1.6" -Enzyme = "0.13.21" +Enzyme = "0.13.22" EnzymeCore = "0.8.8" GPUArraysCore = "0.1.6, 0.2" LinearAlgebra = "1.10" NNlib = "0.9.26" OrderedCollections = "1" Preferences = "1.4" -ReactantCore = "0.1.2" +ReactantCore = "0.1.3" Reactant_jll = "0.0.26" Scratch = "1.2" Statistics = "1.10" diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index 8614bdcd9..f93b32ea4 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -376,6 +376,16 @@ extern "C" MlirModule ConvertLLVMToMLIR(LLVMModuleRef lmod, MlirContext cctx) { return wrap(res); } +#include "llvm/IRReader/IRReader.h" +extern "C" MlirModule ConvertLLVMStrToMLIR(const char* lmod, MlirContext cctx) { + LLVMContext Context; + SMDiagnostic Err; + auto llvmModule = llvm::parseIR(llvm::MemoryBufferRef(lmod, "conversion"), Err, Context); + mlir::MLIRContext &context = *unwrap(cctx); + auto res = mlir::translateLLVMIRToModule(std::move(llvmModule), &context, /*emitExpensiveWarnings*/false, /*dropDICompositeElements*/false).release(); + return wrap(res); +} 
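// Illustrative usage sketch (an editorial aside, not part of this patch): the new
// ConvertLLVMStrToMLIR entry point takes textual LLVM IR plus an existing MlirContext
// and returns the translated MLIR module. The IR string and the `ctx` handle below are
// hypothetical placeholders.
//
//   const char *ir = "define i64 @identity(i64 %x) {\n  ret i64 %x\n}\n";
//   MlirModule mod = ConvertLLVMStrToMLIR(ir, ctx);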
+ /* Note that this */ extern "C" xla::PjRtLoadedExecutable* ClientCompile(PjRtClient * client, MlirModule cmod) { diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index e7157d89c..c718304bd 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -450,6 +450,8 @@ cc_library( "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:TransformDialect", "@llvm-project//mlir:Transforms", + + "@llvm-project//llvm:IRReader", "@llvm-project//llvm:Support", "@llvm-project//llvm:AArch64AsmParser", "@llvm-project//llvm:AArch64CodeGen", diff --git a/ext/ReactantNNlibExt.jl b/ext/ReactantNNlibExt.jl index 8bfa5de02..f85bd1d84 100644 --- a/ext/ReactantNNlibExt.jl +++ b/ext/ReactantNNlibExt.jl @@ -2,16 +2,10 @@ module ReactantNNlibExt using NNlib using GPUArraysCore: @allowscalar -using Reactant: - Reactant, - Ops, - TracedRArray, - AnyTracedRArray, - materialize_traced_array, - MLIR, - TracedRNumber, - get_mlir_data, - set_mlir_data! +using Reactant: Reactant, Ops, TracedRArray, AnyTracedRArray, MLIR, TracedRNumber + +using Reactant.TracedUtils: materialize_traced_array, get_mlir_data, set_mlir_data! + using ReactantCore: @trace using LinearAlgebra: LinearAlgebra, triu @@ -238,9 +232,9 @@ function NNlib.batched_mul!( if size(x, 3) != size(y, 3) B = max(size(x, 3), size(y, 3)) if size(x, 3) == 1 - x = Reactant.broadcast_to_size(x, (size(x, 1), size(x, 2), B)) + x = Reactant.TracedUtils.broadcast_to_size(x, (size(x, 1), size(x, 2), B)) elseif size(y, 3) == 1 - y = Reactant.broadcast_to_size(y, (size(y, 1), size(y, 2), B)) + y = Reactant.TracedUtils.broadcast_to_size(y, (size(y, 1), size(y, 2), B)) end end @@ -250,9 +244,9 @@ function NNlib.batched_mul!( if size(x, 1) != size(y, 1) B = max(size(x, 1), size(y, 1)) if size(x, 1) == 1 - x = Reactant.broadcast_to_size(x, (B, size(x, 2), size(x, 3))) + x = Reactant.TracedUtils.broadcast_to_size(x, (B, size(x, 2), size(x, 3))) elseif size(y, 1) == 1 - y = Reactant.broadcast_to_size(y, (B, size(y, 2), size(y, 3))) + y = Reactant.TracedUtils.broadcast_to_size(y, (B, size(y, 2), size(y, 3))) end end @@ -270,7 +264,7 @@ end function NNlib.pad_constant( x::AnyTracedRArray{T,N}, pad::NTuple{N,Tuple{Int,Int}}, value ) where {T,N} - value = Reactant.promote_to(TracedRNumber{T}, value) + value = Reactant.TracedUtils.promote_to(TracedRNumber{T}, value) low = [i[1] for i in pad] high = [i[2] for i in pad] interior = [0 for i in pad] @@ -329,7 +323,8 @@ function NNlib.gather!(dst::TracedRArray, src::AnyTracedRArray, idxs::AbstractAr start_sizes = ntuple(i -> size(src, i), dims) results = map(CartesianIndices(idxs)) do k res = @allowscalar src[colons..., Tuple(idxs[k])...] 
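        # Editorial note (illustrative, not introduced by this patch): when `src` is fully
        # indexed (no leading colons), `res` comes back as a scalar `TracedRNumber`; the
        # line below widens it to a 1-element array with `broadcast_to_size(res, (1,))`
        # so it can still be reshaped and concatenated with the other gathered slices.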
- res isa TracedRNumber && (res = Reactant.broadcast_to_size(res, (1,))) + res isa TracedRNumber && + (res = Reactant.TracedUtils.broadcast_to_size(res, (1,))) return reshape(res, start_sizes..., :) end res = reshape(cat(results...; dims=(dims + 1)), size(dst)) diff --git a/ext/ReactantStatisticsExt.jl b/ext/ReactantStatisticsExt.jl index f733511af..40db81a8e 100644 --- a/ext/ReactantStatisticsExt.jl +++ b/ext/ReactantStatisticsExt.jl @@ -1,6 +1,7 @@ module ReactantStatisticsExt -using Reactant: AnyTracedRArray, materialize_traced_array +using Reactant: AnyTracedRArray +using Reactant.TracedUtils: materialize_traced_array using Statistics: Statistics function Statistics.mean(A::AnyTracedRArray{T,N}; dims=:) where {T,N} diff --git a/ext/ReactantYaoBlocksExt.jl b/ext/ReactantYaoBlocksExt.jl index 2542d8a08..cc16e51be 100644 --- a/ext/ReactantYaoBlocksExt.jl +++ b/ext/ReactantYaoBlocksExt.jl @@ -1,12 +1,13 @@ module ReactantYaoBlocksExt using Reactant +using Reactant.TracedUtils: broadcast_to_size using YaoBlocks function YaoBlocks.mat( ::Type{T}, R::RotationGate{D,Reactant.TracedRNumber{S},<:XGate} ) where {D,T,S} - M = Reactant.broadcast_to_size(zero(T), (2, 2)) + M = broadcast_to_size(zero(T), (2, 2)) c = cos(R.theta / 2) s = -im * sin(R.theta / 2) M[1, 1] = c @@ -19,7 +20,7 @@ end function YaoBlocks.mat( ::Type{T}, R::RotationGate{D,Reactant.TracedRNumber{S},<:YGate} ) where {D,T,S} - M = Reactant.broadcast_to_size(zero(T), (2, 2)) + M = broadcast_to_size(zero(T), (2, 2)) c = cos(R.theta / 2) s = sin(R.theta / 2) M[1, 1] = c @@ -32,7 +33,7 @@ end function YaoBlocks.mat( ::Type{T}, R::RotationGate{D,Reactant.TracedRNumber{S},<:ZGate} ) where {D,T,S} - M = Reactant.broadcast_to_size(zero(T), (2, 2)) + M = broadcast_to_size(zero(T), (2, 2)) x = exp(im * R.theta / 2) M[1, 1] = conj(x) M[2, 2] = x diff --git a/lib/ReactantCore/Project.toml b/lib/ReactantCore/Project.toml index a11f5c66a..bec50b45e 100644 --- a/lib/ReactantCore/Project.toml +++ b/lib/ReactantCore/Project.toml @@ -1,7 +1,7 @@ name = "ReactantCore" uuid = "a3311ec8-5e00-46d5-b541-4f83e724a433" authors = ["William Moses ", "Valentin Churavy ", "Sergio Sánchez Ramírez ", "Paul Berg ", "Avik Pal "] -version = "0.1.2" +version = "0.1.3" [deps] ExpressionExplorer = "21656369-7473-754a-2065-74616d696c43" diff --git a/lib/ReactantCore/src/ReactantCore.jl b/lib/ReactantCore/src/ReactantCore.jl index 32c663fce..f99d6cab9 100644 --- a/lib/ReactantCore/src/ReactantCore.jl +++ b/lib/ReactantCore/src/ReactantCore.jl @@ -153,7 +153,9 @@ function trace_for(mod, expr) all_syms = Expr(:tuple, counter, external_syms...) args_init = Expr( - :tuple, :(Reactant.promote_to(Reactant.TracedRNumber{Int}, 0)), external_syms... 
+ :tuple, + :(Reactant.TracedUtils.promote_to(Reactant.TracedRNumber{Int}, 0)), + external_syms..., ) reactant_code_block = quote @@ -161,7 +163,7 @@ function trace_for(mod, expr) cond_fn = $(all_syms) -> begin local num_iters = div($limit - $start, $step, RoundDown) - local num_iters = Reactant.promote_to( + local num_iters = Reactant.TracedUtils.promote_to( Reactant.TracedRNumber{Int64}, num_iters ) $counter < num_iters + 1 diff --git a/src/Compiler.jl b/src/Compiler.jl index 586f33b05..5f7158d82 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -292,7 +292,7 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true) func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = MLIR.IR.mmodule!(mod) do MLIR.IR.block!(MLIR.IR.body(mod)) do - return Reactant.make_mlir_fn(f, args, (), "main", true) + return Reactant.TracedUtils.make_mlir_fn(f, args, (), "main", true) end end @@ -779,6 +779,13 @@ function compile(f, args; client=nothing, optimize=true, sync=false) return register_thunk(fname, body) end +# Compiling within a compile should return simply the original function +Reactant.@reactant_override function Reactant.Compiler.compile( + f, args; client=nothing, optimize=true, sync=false +) + return f +end + # inspired by RuntimeGeneratedFunction.jl const __thunk_body_cache = Dict{Symbol,Expr}() diff --git a/src/ConcreteRArray.jl b/src/ConcreteRArray.jl index dac67bf69..e9d9c02d7 100644 --- a/src/ConcreteRArray.jl +++ b/src/ConcreteRArray.jl @@ -99,7 +99,7 @@ end function Base.convert( ::Type{T}, X::WrappedConcreteRArray{ElType,N} ) where {T<:Array,ElType,N} - fn = compile(materialize_traced_array, (X,)) + fn = compile(TracedUtils.materialize_traced_array, (X,)) return convert(Array, fn(X)) end Base.Array(x::AnyConcreteRArray) = convert(Array, x) @@ -345,3 +345,11 @@ end buffer_on_cpu(::Any) = true buffer_on_cpu(x::ConcreteRArray) = XLA.BufferOnCPU(x.data.buffer) + +function Ops.constant(x::ConcreteRArray; kwargs...) + return Ops.constant(Base.convert(Array, x); kwargs...) +end + +function Ops.constant(x::ConcreteRNumber{T}; kwargs...) where {T} + return Ops.constant(Base.convert(T, x); kwargs...) 
+end diff --git a/src/ControlFlow.jl b/src/ControlFlow.jl index 3b30c4cb6..0e0c00195 100644 --- a/src/ControlFlow.jl +++ b/src/ControlFlow.jl @@ -1,7 +1,7 @@ function ReactantCore.traced_if( cond::TracedRNumber{Bool}, true_fn::TFn, false_fn::FFn, args ) where {TFn,FFn} - (_, true_branch_compiled, true_branch_results, _, _, _, _, _, true_linear_results) = Reactant.make_mlir_fn( + (_, true_branch_compiled, true_branch_results, _, _, _, _, _, true_linear_results) = Reactant.TracedUtils.make_mlir_fn( true_fn, args, (), @@ -12,7 +12,7 @@ function ReactantCore.traced_if( construct_function_without_args=true, ) - (_, false_branch_compiled, false_branch_results, _, _, _, _, _, false_linear_results) = Reactant.make_mlir_fn( + (_, false_branch_compiled, false_branch_results, _, _, _, _, _, false_linear_results) = Reactant.TracedUtils.make_mlir_fn( false_fn, args, (), @@ -36,16 +36,16 @@ function ReactantCore.traced_if( returned `$(typeof(tr))`, false branch returned `$(typeof(fr))`.") elseif tr isa MissingTracedValue push!(result_types, MLIR.IR.type(fr.mlir_data)) - push!(linear_results, new_traced_value(false_linear_results[i])) + push!(linear_results, TracedUtils.new_traced_value(false_linear_results[i])) push!(true_block_insertions, (i => linear_results[end])) else push!(result_types, MLIR.IR.type(tr.mlir_data)) - push!(linear_results, new_traced_value(true_linear_results[i])) + push!(linear_results, TracedUtils.new_traced_value(true_linear_results[i])) push!(false_block_insertions, (i => linear_results[end])) end else push!(result_types, MLIR.IR.type(tr.mlir_data)) - push!(linear_results, new_traced_value(tr)) + push!(linear_results, TracedUtils.new_traced_value(tr)) end end @@ -82,13 +82,13 @@ function ReactantCore.traced_while( # We promote all incoming args (is there a better way to do this?) 
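    # Editorial sketch (not part of this change): a plain captured `Int` such as `0` would be
    # lifted here via `Reactant.TracedUtils.promote_to(Reactant.TracedRNumber{Int}, 0)`, the
    # same rewrite applied in `trace_for` above, so every value threaded through the while
    # region is a traced value before the condition and body functions are built.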
traced_args = [ if v isa Number && !(v isa TracedType) - Reactant.promote_to(TracedRNumber{typeof(v)}, v) + Reactant.TracedUtils.promote_to(TracedRNumber{typeof(v)}, v) else v end for v in args ] - (_, cond_fn_compiled, cond_fn_results, _, _, _, _, in_tys, cond_fn_linear_results) = Reactant.make_mlir_fn( + (_, cond_fn_compiled, cond_fn_results, _, _, _, _, in_tys, cond_fn_linear_results) = Reactant.TracedUtils.make_mlir_fn( cond_fn, traced_args, (), @@ -99,7 +99,7 @@ function ReactantCore.traced_while( do_transpose=false, ) - (_, body_fn_compiled, body_fn_results, _, _, _, _, _, body_fn_linear_results) = Reactant.make_mlir_fn( + (_, body_fn_compiled, body_fn_results, _, _, _, _, _, body_fn_linear_results) = Reactant.TracedUtils.make_mlir_fn( body_fn, traced_args, (), diff --git a/src/Interpreter.jl b/src/Interpreter.jl index 8a039e17a..72e27c5d8 100644 --- a/src/Interpreter.jl +++ b/src/Interpreter.jl @@ -21,6 +21,14 @@ import Core.Compiler: mapany, MethodResultPure +Base.Experimental.@MethodTable(REACTANT_METHOD_TABLE) + +function var"@reactant_override"(__source__::LineNumberNode, __module__::Module, def) + return Base.Experimental.var"@overlay"( + __source__, __module__, :(Reactant.REACTANT_METHOD_TABLE), def + ) +end + function set_reactant_abi( interp, @nospecialize(f), @@ -31,71 +39,11 @@ function set_reactant_abi( ) (; fargs, argtypes) = arginfo - if ( - (f === Enzyme.autodiff) || - (f === Enzyme.autodiff_deferred) || - (f === Enzyme.gradient) || - (f === Enzyme.jacobian) - ) && (length(argtypes) >= 2) - if widenconst(argtypes[2]) <: Enzyme.Mode - newmode = Enzyme.set_abi(widenconst(argtypes[2]), ReactantABI) - if newmode != widenconst(argtypes[2]) - newmodev = newmode() - arginfo2 = ArgInfo( - if fargs isa Nothing - nothing - else - [fargs[1], :($(newmodev)), fargs[3:end]...] - end, - [argtypes[1], Core.Const(newmodev), argtypes[3:end]...], - ) - return abstract_call_known(interp, f, arginfo2, si, sv, max_methods) - end - end - end - - if length(argtypes) >= 5 && - f === Core.kwcall && - ( - widenconst(argtypes[3]) == typeof(Enzyme.gradient) || - widenconst(argtypes[3]) == typeof(Enzyme.jacobian) - ) && - widenconst(argtypes[4]) <: Enzyme.Mode - newmode = Enzyme.set_abi(widenconst(argtypes[4]), ReactantABI) - if newmode != widenconst(argtypes[4]) - newmodev = newmode() - arginfo2 = ArgInfo( - if fargs isa Nothing - nothing - else - [fargs[1:3]..., :($(newmodev)), fargs[5:end]...] - end, - [argtypes[1:3]..., Core.Const(newmodev), argtypes[5:end]...], - ) - return abstract_call_known(interp, f, arginfo2, si, sv, max_methods) - end - end - - if length(argtypes) >= 5 && - methods(f)[1].module == Enzyme && - widenconst(argtypes[5]) <: Enzyme.Mode && - ( - widenconst(argtypes[4]) == typeof(Enzyme.gradient) || - widenconst(argtypes[4]) == typeof(Enzyme.jacobian) - ) - newmode = Enzyme.set_abi(widenconst(argtypes[5]), ReactantABI) - if newmode != widenconst(argtypes[5]) - newmodev = newmode() - arginfo2 = ArgInfo( - if fargs isa Nothing - nothing - else - [fargs[1:4]..., :($(newmodev)), fargs[6:end]...] - end, - [argtypes[1:4]..., Core.Const(newmodev), argtypes[6:end]...], - ) - return abstract_call_known(interp, f, arginfo2, si, sv, max_methods) - end + # Improve inference by considering call_with_reactant as having the same results as + # the original call + if f === Reactant.call_with_reactant + arginfo2 = ArgInfo(fargs isa Nothing ? 
nothing : fargs[2:end], argtypes[2:end]) + return abstract_call(interp, arginfo2::ArgInfo, si, sv, max_methods) end return Base.@invoke abstract_call_known( @@ -108,15 +56,13 @@ function set_reactant_abi( ) end -function set_reactant_abi end - @static if Enzyme.GPUCompiler.HAS_INTEGRATED_CACHE struct ReactantCacheToken end function ReactantInterpreter(; world::UInt=Base.get_world_counter()) return Enzyme.Compiler.Interpreter.EnzymeInterpreter( ReactantCacheToken(), - nothing, #=mt=# + REACTANT_METHOD_TABLE, world, true, #=forward_rules=# true, #=reverse_rules=# @@ -132,7 +78,7 @@ else ) return Enzyme.Compiler.Interpreter.EnzymeInterpreter( REACTANT_CACHE, - nothing, #=mt=# + REACTANT_METHOD_TABLE, world, true, #=forward_rules=# true, #=forward_rules=# @@ -196,38 +142,30 @@ const enzyme_constnoneed = 5 enzyme_outnoneed end -function push_val!(ad_inputs, x, path) - for p in path - x = traced_getfield(x, p) - end - x = x.mlir_data - return push!(ad_inputs, x) -end - function push_acts!(ad_inputs, x::Const, path, reverse) - return push_val!(ad_inputs, x.val, path) + return TracedUtils.push_val!(ad_inputs, x.val, path) end function push_acts!(ad_inputs, x::Active, path, reverse) - return push_val!(ad_inputs, x.val, path) + return TracedUtils.push_val!(ad_inputs, x.val, path) end function push_acts!(ad_inputs, x::Duplicated, path, reverse) - push_val!(ad_inputs, x.val, path) + TracedUtils.push_val!(ad_inputs, x.val, path) if !reverse - push_val!(ad_inputs, x.dval, path) + TracedUtils.push_val!(ad_inputs, x.dval, path) end end function push_acts!(ad_inputs, x::DuplicatedNoNeed, path, reverse) - push_val!(ad_inputs, x.val, path) + TracedUtils.push_val!(ad_inputs, x.val, path) if !reverse - push_val!(ad_inputs, x.dval, path) + TracedUtils.push_val!(ad_inputs, x.dval, path) end end function push_acts!(ad_inputs, x::BatchDuplicated, path, reverse) - push_val!(ad_inputs, x.val, path) + TracedUtils.push_val!(ad_inputs, x.val, path) if !reverse ET = eltype(x.val) predims = size(x.val) @@ -237,12 +175,12 @@ function push_acts!(ad_inputs, x::BatchDuplicated, path, reverse) ), ) tval = TracedRArray{ET,length(predims) + 1}((), cval, (length(x.dval), predims...)) - push_val!(ad_inputs, tval, path) + TracedUtils.push_val!(ad_inputs, tval, path) end end function push_acts!(ad_inputs, x::BatchDuplicatedNoNeed, path, reverse) - push_val!(ad_inputs, x.val, path) + TracedUtils.push_val!(ad_inputs, x.val, path) if !reverse ET = eltype(x.val) predims = size(x.val) @@ -252,7 +190,7 @@ function push_acts!(ad_inputs, x::BatchDuplicatedNoNeed, path, reverse) ), ) tval = TracedRArray{ET,length(predims) + 1}((), cval, (length(x.dval), predims...)) - push_val!(ad_inputs, tval, path) + TracedUtils.push_val!(ad_inputs, tval, path) end end @@ -278,57 +216,6 @@ function set_act!(inp, path, reverse, tostore; emptypath=false) end end -function set!(x, path, tostore; emptypath=false) - for p in path - x = traced_getfield(x, p) - end - - x.mlir_data = tostore - - if emptypath - x.paths = () - end -end - -function get_argidx(x) - for path in x.paths - if length(path) == 0 - continue - end - if path[1] == :args - return path[2]::Int, path - end - end - throw(AssertionError("No path found for $x")) -end -function get_residx(x) - for path in x.paths - if length(path) == 0 - continue - end - if path[1] == :result - return path - end - end - throw(AssertionError("No path found $x")) -end - -function has_residx(x) - for path in x.paths - if length(path) == 0 - continue - end - if path[1] == :result - return true - end - end - return 
false -end - -function get_attribute_by_name(operation, name) - return MLIR.IR.Attribute(MLIR.API.mlirOperationGetAttributeByName(operation, name)) -end - function overload_autodiff( ::CMode, f::FA, ::Type{A}, args::Vararg{Enzyme.Annotation,Nargs} ) where {CMode<:Enzyme.Mode,FA<:Enzyme.Annotation,A<:Enzyme.Annotation,Nargs} @@ -348,7 +235,7 @@ function overload_autodiff( primf = f.val primargs = ((v.val for v in args)...,) - fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = make_mlir_fn( + fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = Reactant.TracedUtils.make_mlir_fn( primf, primargs, (), string(f) * "_autodiff", false ) @@ -356,7 +243,7 @@ function overload_autodiff( ad_inputs = MLIR.IR.Value[] for a in linear_args - idx, path = get_argidx(a) + idx, path = TracedUtils.get_argidx(a) if idx == 1 && fnwrap push!(activity, act_from_type(f, reverse)) push_acts!(ad_inputs, f, path[3:end], reverse) @@ -375,19 +262,24 @@ function overload_autodiff( @inline needs_primal(::Type{<:Enzyme.ForwardMode{ReturnPrimal}}) where {ReturnPrimal} = ReturnPrimal for a in linear_results - if has_residx(a) + if TracedUtils.has_residx(a) if needs_primal(CMode) - push!(outtys, transpose_ty(MLIR.IR.type(a.mlir_data))) + push!(outtys, TracedUtils.transpose_ty(MLIR.IR.type(a.mlir_data))) end if CMode <: Enzyme.ForwardMode && !(A <: Enzyme.Const) if width == 1 - push!(outtys, transpose_ty(MLIR.IR.type(a.mlir_data))) + push!(outtys, TracedUtils.transpose_ty(MLIR.IR.type(a.mlir_data))) else - push!(outtys, batch_ty(width, transpose_ty(MLIR.IR.type(a.mlir_data)))) + push!( + outtys, + TracedUtils.batch_ty( + width, TracedUtils.transpose_ty(MLIR.IR.type(a.mlir_data)) + ), + ) end end else - push!(outtys, transpose_ty(MLIR.IR.type(a.mlir_data))) + push!(outtys, TracedUtils.transpose_ty(MLIR.IR.type(a.mlir_data))) end end for (i, act) in enumerate(activity) @@ -395,30 +287,30 @@ function overload_autodiff( if width == 1 push!(outtys, in_tys[i]) else - push!(outtys, batch_ty(width, in_tys[i])) + push!(outtys, TracedUtils.batch_ty(width, in_tys[i])) end end end ret_activity = Int32[] for a in linear_results - if has_residx(a) + if TracedUtils.has_residx(a) act = act_from_type(A, reverse, needs_primal(CMode)) push!(ret_activity, act) if act == enzyme_out || act == enzyme_outnoneed - attr = fill(MLIR.IR.Attribute(eltype(a)(1)), mlir_type(a)) + attr = fill(MLIR.IR.Attribute(eltype(a)(1)), Ops.mlir_type(a)) cst = MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1) push!(ad_inputs, cst) end else - idx, path = get_argidx(a) + idx, path = TracedUtils.get_argidx(a) if idx == 1 && fnwrap act = act_from_type(f, reverse, true) push!(ret_activity, act) if act != enzyme_out && act != enzyme_outnoneed continue end - push_val!(ad_inputs, f.dval, path[3:end]) + TracedUtils.push_val!(ad_inputs, f.dval, path[3:end]) else if fnwrap idx -= 1 @@ -428,7 +320,7 @@ function overload_autodiff( if act != enzyme_out && act != enzyme_outnoneed continue end - push_val!(ad_inputs, args[idx].dval, path[3:end]) + TracedUtils.push_val!(ad_inputs, args[idx].dval, path[3:end]) end end end @@ -439,10 +331,10 @@ function overload_autodiff( )::MLIR.API.MlirAttribute return MLIR.IR.Attribute(val) end - fname = get_attribute_by_name(func2, "sym_name") + fname = TracedUtils.get_attribute_by_name(func2, "sym_name") fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) res = (reverse ? 
MLIR.Dialects.enzyme.autodiff : MLIR.Dialects.enzyme.fwddiff)( - [transpose_val(v) for v in ad_inputs]; + [TracedUtils.transpose_val(v) for v in ad_inputs]; outputs=outtys, fn=fname, activity=MLIR.IR.Attribute([act_attr(a) for a in activity]), @@ -465,20 +357,20 @@ function overload_autodiff( end for a in linear_results - if has_residx(a) + if TracedUtils.has_residx(a) if needs_primal(CMode) - path = get_residx(a) - tval = transpose_val(MLIR.IR.result(res, residx)) - set!(result, path[2:end], tval) + path = TracedUtils.get_residx(a) + tval = TracedUtils.transpose_val(MLIR.IR.result(res, residx)) + TracedUtils.set!(result, path[2:end], tval) residx += 1 end if CMode <: Enzyme.ForwardMode && !(A <: Enzyme.Const) - path = get_residx(a) + path = TracedUtils.get_residx(a) if width == 1 - tval = transpose_val(MLIR.IR.result(res, residx)) - set!(dresult, path[2:end], tval) + tval = TracedUtils.transpose_val(MLIR.IR.result(res, residx)) + TracedUtils.set!(dresult, path[2:end], tval) else - tval = transpose_val(MLIR.IR.result(res, residx)) + tval = TracedUtils.transpose_val(MLIR.IR.result(res, residx)) for i in 1:width sz = size(a) starts = Int64[i] @@ -488,21 +380,29 @@ function overload_autodiff( push!(limits, v) end sval = Ops.slice(sval, starts, limits) - set!(dresult[i], path[2:end], sval) + TracedUtils.set!(dresult[i], path[2:end], sval) end end residx += 1 end else - idx, path = get_argidx(a) + idx, path = TracedUtils.get_argidx(a) if idx == 1 && fnwrap - set!(f.val, path[3:end], transpose_val(MLIR.IR.result(res, residx))) + TracedUtils.set!( + f.val, + path[3:end], + TracedUtils.transpose_val(MLIR.IR.result(res, residx)), + ) residx += 1 else if fnwrap idx -= 1 end - set!(args[idx].val, path[3:end], transpose_val(MLIR.IR.result(res, residx))) + TracedUtils.set!( + args[idx].val, + path[3:end], + TracedUtils.transpose_val(MLIR.IR.result(res, residx)), + ) residx += 1 end end @@ -510,7 +410,7 @@ function overload_autodiff( restup = Any[(a isa Active) ? copy(a) : nothing for a in args] for a in linear_args - idx, path = get_argidx(a) + idx, path = TracedUtils.get_argidx(a) if idx == 1 && fnwrap if act_from_type(f, reverse) != enzyme_out continue @@ -520,7 +420,12 @@ function overload_autodiff( residx += 1 continue end - set_act!(f, path[3:end], reverse, transpose_val(MLIR.IR.result(res, residx))) + set_act!( + f, + path[3:end], + reverse, + TracedUtils.transpose_val(MLIR.IR.result(res, residx)), + ) else if fnwrap idx -= 1 @@ -533,14 +438,17 @@ function overload_autodiff( args[idx], path[3:end], false, - transpose_val(MLIR.IR.result(res, residx)); + TracedUtils.transpose_val(MLIR.IR.result(res, residx)); emptypaths=true, ) #=reverse=# residx += 1 continue end set_act!( - args[idx], path[3:end], reverse, transpose_val(MLIR.IR.result(res, residx)) + args[idx], + path[3:end], + reverse, + TracedUtils.transpose_val(MLIR.IR.result(res, residx)), ) end residx += 1 @@ -572,58 +480,14 @@ function overload_autodiff( end end -@inline function Enzyme.autodiff_deferred( - rmode::Enzyme.ReverseMode{ - ReturnPrimal,RuntimeActivity,ReactantABI,Holomorphic,ErrIfFuncWritten - }, - f::FA, - rt::Type{A}, - args::Vararg{Annotation,Nargs}, -) where { - FA<:Annotation, - A<:Annotation, - ReturnPrimal, - RuntimeActivity, - Holomorphic, - Nargs, - ErrIfFuncWritten, -} - return overload_autodiff(rmode, f, rt, args...) 
-end - -@inline function Enzyme.autodiff_deferred( - rmode::ForwardMode{ReturnPrimal,ReactantABI,ErrIfFuncWritten,RuntimeActivity}, - f::FA, - rt::Type{A}, - args::Vararg{Annotation,Nargs}, -) where {FA<:Annotation,A<:Annotation,ReturnPrimal,Nargs,ErrIfFuncWritten,RuntimeActivity} - return overload_autodiff(rmode, f, rt, args...) -end - -@inline function Enzyme.autodiff( - rmode::Enzyme.ReverseMode{ - ReturnPrimal,RuntimeActivity,ReactantABI,Holomorphic,ErrIfFuncWritten - }, - f::FA, - rt::Type{A}, - args::Vararg{Annotation,Nargs}, -) where { - FA<:Annotation, - A<:Annotation, - ReturnPrimal, - RuntimeActivity, - Holomorphic, - Nargs, - ErrIfFuncWritten, -} +@reactant_override @noinline function Enzyme.autodiff_deferred( + rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs} +) where {FA<:Annotation,A<:Annotation,Nargs} return overload_autodiff(rmode, f, rt, args...) end -@inline function Enzyme.autodiff( - rmode::ForwardMode{ReturnPrimal,ReactantABI,ErrIfFuncWritten,RuntimeActivity}, - f::FA, - rt::Type{A}, - args::Vararg{Annotation,Nargs}, -) where {FA<:Annotation,A<:Annotation,ReturnPrimal,Nargs,ErrIfFuncWritten,RuntimeActivity} +@reactant_override @noinline function Enzyme.autodiff( + rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs} +) where {FA<:Annotation,A<:Annotation,Nargs} return overload_autodiff(rmode, f, rt, args...) end diff --git a/src/Ops.jl b/src/Ops.jl index 013e0dbc8..376122e8b 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -4,21 +4,64 @@ module Ops using ..MLIR: MLIR using ..MLIR.Dialects: stablehlo, chlo, enzyme -using ..Reactant: - Reactant, - ConcreteRArray, - ConcreteRNumber, - TracedRArray, - TracedRNumber, - mlir_type, - mlir_stacktrace +using ..Reactant: Reactant, TracedRArray, TracedRNumber, RArray, RNumber, MissingTracedValue + +function mlir_type(x::RArray{T,N}) where {T,N} + return MLIR.IR.TensorType(size(x), MLIR.IR.Type(T)) +end + +mlir_type(::RNumber{T}) where {T} = MLIR.IR.TensorType((), MLIR.IR.Type(T)) + +mlir_type(::MissingTracedValue) = MLIR.IR.TensorType((), MLIR.IR.Type(Bool)) + +function mlir_type(::Type{<:RArray{T,N}}, shape) where {T,N} + @assert length(shape) == N + return MLIR.IR.TensorType(shape, MLIR.IR.Type(T)) +end + +function mlir_type(::Type{<:RNumber{T}}) where {T} + return MLIR.IR.TensorType((), MLIR.IR.Type(T)) +end + +function mlir_type(::Type{<:MissingTracedValue}) + return MLIR.IR.TensorType((), MLIR.IR.Type(Bool)) +end + +const DEBUG_MODE::Ref{Bool} = Ref(false) + +function with_debug(f) + old = DEBUG_MODE[] + DEBUG_MODE[] = true + try + return f() + finally + DEBUG_MODE[] = old + end +end + +@noinline function mlir_stacktrace(name, file, line)::MLIR.IR.Location + # calling `stacktrace` can add a lot of time overhead, so let's avoid adding debug info if not used + if DEBUG_MODE[] + return MLIR.IR.Location(name, MLIR.IR.Location(file, line, 0)) + end + + # retrieve current stacktrace, remove this function's frame and translate to MLIR Location + st = stacktrace() + deleteat!(st, 1) + return mapfoldl(MLIR.IR.Location, st) do stackframe + name = string(stackframe.func) + file = stackframe.file + line = stackframe.line + return MLIR.IR.Location(name, MLIR.IR.Location(file, line, 0)) + end +end struct Token mlir_data::MLIR.IR.Value end # constant ops -function constant( +@noinline function constant( x::DenseArray{T,N}; location=mlir_stacktrace("constant", @__FILE__, @__LINE__) ) where {T,N} value = MLIR.IR.DenseElementsAttribute(x) @@ -27,28 +70,13 @@ function constant( return TracedRArray{T,N}((), 
res, size(x)) end -function constant(x::ConcreteRArray; kwargs...) - return stablehlo.constant(Base.convert(Array, x); kwargs...) -end - -function constant( +@noinline function constant( x::T; location=mlir_stacktrace("constant", @__FILE__, @__LINE__) ) where {T<:Number} res = constant(fill(x); location) return TracedRNumber{T}((), res.mlir_data) end -function constant( - x::ConcreteRNumber{T}; location=mlir_stacktrace("constant", @__FILE__, @__LINE__) -) where {T} - output = mlir_type(TracedRArray{T,0}, ()) - value = MLIR.IR.DenseElementsAttribute( - fill(MLIR.IR.Attribute(Base.convert(T, x)), output) - ) - res = MLIR.IR.result(stablehlo.constant(; output, value, location)) - return TracedRNumber{T,N}((), res) -end - # unary elementwise ops for (dialect, op) in [ (:stablehlo, :abs), @@ -90,7 +118,7 @@ for (dialect, op) in [ (:chlo, :sinh), ] @eval begin - function $op( + @noinline function $op( x::TracedRArray{T,N}; location=mlir_stacktrace($(string(op)), @__FILE__, @__LINE__), ) where {T,N} @@ -102,7 +130,7 @@ for (dialect, op) in [ return TracedRArray{T,N}((), res, size(x)) end - function $op( + @noinline function $op( x::TracedRNumber{T}; location=mlir_stacktrace($(string(op)), @__FILE__, @__LINE__), ) where {T} @@ -138,7 +166,7 @@ for (dialect, op) in [ (:chlo, :zeta), ] @eval begin - function $op( + @noinline function $op( a::TracedRArray{T,N}, b::TracedRArray{T,N}; location=mlir_stacktrace($(string(op)), @__FILE__, @__LINE__), @@ -154,7 +182,7 @@ for (dialect, op) in [ return TracedRArray{T,N}((), res, size(a)) end - function $op( + @noinline function $op( a::TracedRNumber{T}, b::TracedRNumber{T}; location=mlir_stacktrace($(string(op)), @__FILE__, @__LINE__), @@ -180,7 +208,7 @@ for (dialect, op) in [ (:chlo, :is_pos_inf), ] @eval begin - function $op( + @noinline function $op( x::TracedRArray{T,N}; location=mlir_stacktrace($(string(op)), @__FILE__, @__LINE__), ) where {T,N} @@ -192,7 +220,7 @@ for (dialect, op) in [ return TracedRArray{Bool,N}((), res, size(x)) end - function $op( + @noinline function $op( x::TracedRNumber{T}; location=mlir_stacktrace($(string(op)), @__FILE__, @__LINE__), ) where {T} @@ -206,7 +234,7 @@ for (dialect, op) in [ end end -function is_finite( +@noinline function is_finite( x::TracedRArray{T,N}; location=mlir_stacktrace("is_finite", @__FILE__, @__LINE__) ) where {T,N} res = MLIR.IR.result( @@ -217,7 +245,7 @@ function is_finite( return TracedRArray{Bool,N}((), res, size(x)) end -function is_finite( +@noinline function is_finite( x::TracedRNumber{T}; location=mlir_stacktrace("is_finite", @__FILE__, @__LINE__) ) where {T} res = MLIR.IR.result( @@ -227,7 +255,7 @@ function is_finite( end # fixes to default automated implementations -function abs( +@noinline function abs( x::TracedRArray{Complex{T},N}; location=mlir_stacktrace("abs", @__FILE__, @__LINE__) ) where {T,N} res = MLIR.IR.result( @@ -236,7 +264,7 @@ function abs( return TracedRArray{T,N}((), res, size(x)) end -function abs( +@noinline function abs( x::TracedRNumber{Complex{T}}; location=mlir_stacktrace("abs", @__FILE__, @__LINE__) ) where {T} res = MLIR.IR.result( @@ -250,7 +278,7 @@ function reshape(x::TracedRArray, dims...; kwargs...) return reshape(x, collect(dims); kwargs...) 
end -function reshape( +@noinline function reshape( x::TracedRArray{T,N}, dims::Vector{Int}; location=mlir_stacktrace("reshape", @__FILE__, @__LINE__), @@ -265,7 +293,7 @@ function reshape( return transpose(result, Int64[length(dims):-1:1...]) end -function get_dimension_size( +@noinline function get_dimension_size( x::TracedRArray{T,N}, dim; location=mlir_stacktrace("get_dimension_size", @__FILE__, @__LINE__), @@ -279,7 +307,7 @@ function get_dimension_size( return TracedRNumber{Int32}((), res) end -function set_dimension_size( +@noinline function set_dimension_size( x::TracedRArray{T,N}, size::TracedRNumber{Int}, dim::Int; @@ -298,7 +326,7 @@ function set_dimension_size( return TracedRArray{T,N}((), res, size(x)) end -function transpose( +@noinline function transpose( x::TracedRArray{T,N}, permutation; location=mlir_stacktrace("transpose", @__FILE__, @__LINE__), @@ -312,7 +340,7 @@ function transpose( end # indexing ops -function pad( +@noinline function pad( x::TracedRArray{T,N}, padding_value::TracedRNumber{T}; low=fill(0, N), @@ -334,7 +362,7 @@ function pad( return TracedRArray{T,N}((), res, rsize) end -function slice( +@noinline function slice( x::TracedRArray{T,N}, start_indices, limit_indices; @@ -360,7 +388,7 @@ function slice( end # numerics -function complex( +@noinline function complex( real::TracedRArray{T,N}, imag::TracedRArray{T,N}; location=mlir_stacktrace("complex", @__FILE__, @__LINE__), @@ -376,7 +404,7 @@ function complex( return TracedRArray{Complex{T},N}((), res, size(real)) end -function complex( +@noinline function complex( real::TracedRNumber{T}, imag::TracedRNumber{T}; location=mlir_stacktrace("complex", @__FILE__, @__LINE__), @@ -392,7 +420,7 @@ function complex( return TracedRNumber{Complex{T}}((), res) end -function real( +@noinline function real( x::TracedRArray{Complex{T},N}; location=mlir_stacktrace("real", @__FILE__, @__LINE__) ) where {T,N} res = MLIR.IR.result( @@ -401,7 +429,7 @@ function real( return TracedRArray{T,N}((), res, size(x)) end -function real( +@noinline function real( x::TracedRNumber{Complex{T}}; location=mlir_stacktrace("real", @__FILE__, @__LINE__) ) where {T} res = MLIR.IR.result( @@ -410,7 +438,7 @@ function real( return TracedRNumber{T}((), res) end -function imag( +@noinline function imag( x::TracedRArray{Complex{T},N}; location=mlir_stacktrace("imag", @__FILE__, @__LINE__) ) where {T,N} res = MLIR.IR.result( @@ -419,7 +447,7 @@ function imag( return TracedRArray{T,N}((), res, size(x)) end -function imag( +@noinline function imag( x::TracedRNumber{Complex{T}}; location=mlir_stacktrace("imag", @__FILE__, @__LINE__) ) where {T} res = MLIR.IR.result( @@ -443,7 +471,7 @@ end # return TracedRArray{T,N}((), res, size(x)) # end -function fft( +@noinline function fft( x::TracedRArray{T,N}; type::String, length, @@ -485,7 +513,7 @@ function fft( return TracedRArray{Tout,N}((), res, rsize) end -function cholesky( +@noinline function cholesky( x::TracedRArray{T,N}; lower::Bool=false, location=mlir_stacktrace("cholesky", @__FILE__, @__LINE__), @@ -499,7 +527,7 @@ function cholesky( return TracedRArray{T,N}((), res, size(x)) end -function clamp( +@noinline function clamp( min::Union{TracedRNumber{T},TracedRArray{T,N}}, x::TracedRArray{T,N}, max::Union{TracedRNumber{T},TracedRArray{T,N}}; @@ -517,7 +545,7 @@ function clamp( return TracedRArray{T,N}((), res, size(x)) end -function clamp( +@noinline function clamp( min::TracedRNumber{T}, x::TracedRNumber{T}, max::TracedRNumber{T}; @@ -535,7 +563,9 @@ function clamp( return TracedRNumber{T}((), res) 
end -function clamp(min::T, x::Union{TracedRArray{T,N},TracedRNumber{T}}, max::T) where {T,N} +@noinline function clamp( + min::T, x::Union{TracedRArray{T,N},TracedRNumber{T}}, max::T +) where {T,N} return clamp(constant(min), x, constant(max)) end @@ -569,7 +599,7 @@ end # return TracedRArray{T,N}((), res, size(lhs)) # end -function dot_general( +@noinline function dot_general( lhs::TracedRArray{T}, rhs::TracedRArray{T}; contracting_dimensions, @@ -726,7 +756,7 @@ function dot_general( return TracedRArray{T,length(ressize)}((), res, ressize) end -function einsum( +@noinline function einsum( lhs::TracedRArray{T}, rhs::TracedRArray{T}; equation::String, @@ -784,23 +814,29 @@ end # end # paralell ops -function partition_id(; location=mlir_stacktrace("partition_id", @__FILE__, @__LINE__)) +@noinline function partition_id(; + location=mlir_stacktrace("partition_id", @__FILE__, @__LINE__) +) res = MLIR.IR.result(stablehlo.partition_id(; location)) return TracedRNumber{UInt32}((), res) end -function replica_id(; location=mlir_stacktrace("replica_id", @__FILE__, @__LINE__)) +@noinline function replica_id(; + location=mlir_stacktrace("replica_id", @__FILE__, @__LINE__) +) res = MLIR.IR.result(stablehlo.replica_id(; location)) return TracedRNumber{UInt32}((), res) end -function after_all(tokens...; location=mlir_stacktrace("after_all", @__FILE__, @__LINE__)) +@noinline function after_all( + tokens...; location=mlir_stacktrace("after_all", @__FILE__, @__LINE__) +) tokens = [token.mlir_data for token in tokens] res = MLIR.IR.result(stablehlo.after_all(tokens; location)) return Token(res) end -function optimization_barrier( +@noinline function optimization_barrier( operands::Union{TracedRNumber,TracedRArray}...; location=mlir_stacktrace("optimization_barrier", @__FILE__, @__LINE__), ) @@ -821,7 +857,7 @@ function optimization_barrier( ) end -function outfeed( +@noinline function outfeed( operands::Union{TracedRNumber,TracedRArray}...; token, config="", @@ -835,7 +871,7 @@ function outfeed( return Token(res) end -function send( +@noinline function send( operands::Union{TracedRNumber,TracedRArray}...; token, channel_id::Int, @@ -858,7 +894,7 @@ function send( return Token(res) end -function recv( +@noinline function recv( results::Tuple{Type,Vector{Int}}...; token, channel_id::Int, @@ -937,7 +973,7 @@ end # return TracedRArray{T,N}((), res, size(x)) # end -function top_k( +@noinline function top_k( x::TracedRArray{T,N}, k; location=mlir_stacktrace("top_k", @__FILE__, @__LINE__) ) where {T,N} rsize = [size(x)[1:(end - 1)]..., k] @@ -950,7 +986,7 @@ function top_k( ) end -function iota( +@noinline function iota( T::Type, shape::Vector{Int}; iota_dimension, @@ -963,7 +999,7 @@ function iota( return TracedRArray{T,N}((), res, shape) end -function reverse( +@noinline function reverse( x::TracedRArray{T,N}; dimensions, location=mlir_stacktrace("reverse", @__FILE__, @__LINE__), @@ -980,7 +1016,7 @@ function reverse( end # random ops -function rng_bit_generator( +@noinline function rng_bit_generator( seed::TracedRArray{UInt64,1}, shape; algorithm::String="DEFAULT", @@ -996,7 +1032,7 @@ function rng_bit_generator( end # functional ops -function return_( +@noinline function return_( results::Union{TracedRArray,TracedRNumber}...; location=mlir_stacktrace("return_", @__FILE__, @__LINE__), ) @@ -1004,7 +1040,7 @@ function return_( end # control flow ops -function select( +@noinline function select( pred::Union{TracedRArray{Bool,N},TracedRNumber{Bool}}, on_true::TracedRArray{T,N}, on_false::TracedRArray{T,N}, @@ 
-1023,7 +1059,7 @@ function select( return TracedRArray{T,N}((), res, size(on_true)) end -function select( +@noinline function select( pred::TracedRNumber{Bool}, on_true::TracedRNumber{T}, on_false::TracedRNumber{T} ) where {T} res = MLIR.IR.result( @@ -1038,20 +1074,15 @@ function select( end # comparison -function compare( - lhs::Union{TracedRArray{T},TracedRNumber{T}}, - rhs::Union{TracedRArray{T},TracedRNumber{T}}; +@noinline function compare( + lhs::AT, + rhs::AT; comparison_direction::String, compare_type=nothing, location=mlir_stacktrace("compare", @__FILE__, @__LINE__), -) where {T} +) where {AT<:Union{TracedRArray,TracedRNumber}} @assert comparison_direction in ("EQ", "NE", "GE", "GT", "LE", "LT") @assert size(lhs) == size(rhs) - if lhs isa TracedRNumber - @assert rhs isa TracedRNumber - else - @assert rhs isa TracedRArray - end res = MLIR.IR.result( stablehlo.compare( @@ -1070,7 +1101,7 @@ function compare( end # eltype conversion -function convert( +@noinline function convert( ::Type{TracedRArray{T,N}}, x::TracedRArray; location=mlir_stacktrace("convert", @__FILE__, @__LINE__), @@ -1087,7 +1118,7 @@ function convert( ) end -function convert( +@noinline function convert( ::Type{TracedRNumber{T}}, x::TracedRNumber; location=mlir_stacktrace("convert", @__FILE__, @__LINE__), @@ -1129,7 +1160,7 @@ julia> Reactant.@jit( (ConcreteRArray{Float32, 1}(Float32[2.0, 4.0, 6.0]),) ``` """ -function hlo_call( +@noinline function hlo_call( code, args...; func_name="main", diff --git a/src/Reactant.jl b/src/Reactant.jl index 0fc900b24..ba2da588d 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -57,34 +57,6 @@ abstract type RNumber{T<:ReactantPrimitive} <: Number end Base.collect(A::RArray) = copy(A) -function Enzyme.make_zero( - ::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false) -)::RT where {copy_if_inactive,RT<:RArray} - if haskey(seen, prev) - return seen[prev] - end - if Enzyme.Compiler.guaranteed_const_nongen(RT, nothing) - return copy_if_inactive ? 
Base.deepcopy_internal(prev, seen) : prev - end - if RT <: ConcreteRArray - res = RT(zeros(eltype(RT), size(prev))) - seen[prev] = res - return res - end - - if RT <: TracedRArray - res = broadcast_to_size(eltype(RT)(0), size(prev)) - seen[prev] = res - return res - end - - attr = fill(MLIR.IR.Attribute(eltype(RT)(0)), mlir_type(prev)) - cst = MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1) - res = RT((), cst) - seen[prev] = res - return res -end - function ancestor(x::AbstractArray) p_x = parent(x) p_x === x && return x @@ -97,11 +69,58 @@ include("Interpreter.jl") include("utils.jl") -include("ConcreteRArray.jl") +mutable struct TracedRArray{T,N} <: RArray{T,N} + paths::Tuple + mlir_data::Union{Nothing,MLIR.IR.Value} + shape::NTuple{N,Int} + + function TracedRArray{T,N}( + paths::Tuple, mlir_data::Union{Nothing,MLIR.IR.Value}, shape + ) where {T,N} + shape = Tuple(shape) + if !isnothing(mlir_data) + @assert size(MLIR.IR.type(mlir_data)) == shape + end + return new{T,N}(paths, mlir_data, shape) + end +end + +const WrappedTracedRArray{T,N} = WrappedArray{T,N,TracedRArray,TracedRArray{T,N}} +const AnyTracedRArray{T,N} = Union{TracedRArray{T,N},WrappedTracedRArray{T,N}} +const AnyTracedRVector{T} = AnyTracedRArray{T,1} +const AnyTracedRMatrix{T} = Union{ + AnyTracedRArray{T,2},LinearAlgebra.Diagonal{T,TracedRArray{T,1}} +} +const AnyTracedRVecOrMat{T} = Union{AnyTracedRVector{T},AnyTracedRMatrix{T}} + +function TracedRArray(data::MLIR.IR.Value) + data_type = MLIR.IR.type(data) + return TracedRArray{eltype(MLIR.IR.julia_type(data_type)),ndims(data_type)}( + (), data, size(data_type) + ) +end + +mutable struct TracedRNumber{T} <: RNumber{T} + paths::Tuple + mlir_data::Union{Nothing,MLIR.IR.Value} + + function TracedRNumber{T}( + paths::Tuple, mlir_data::Union{Nothing,MLIR.IR.Value} + ) where {T} + if !isnothing(mlir_data) + @assert size(MLIR.IR.type(mlir_data)) == () + end + return new{T}(paths, mlir_data) + end +end + +include("Ops.jl") +include("TracedUtils.jl") + include("TracedRNumber.jl") include("TracedRArray.jl") -include("Ops.jl") +include("ConcreteRArray.jl") include("linear_algebra.jl") @@ -111,6 +130,20 @@ include("ControlFlow.jl") include("Tracing.jl") include("Compiler.jl") +function Enzyme.make_zero( + ::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false) +)::RT where {copy_if_inactive,RT<:RArray} + if haskey(seen, prev) + return seen[prev] + end + if Enzyme.Compiler.guaranteed_const_nongen(eltype(RT), nothing) + return copy_if_inactive ? 
Base.deepcopy_internal(prev, seen) : prev + end + res = zero(prev) + seen[prev] = res + return res +end + using .Compiler: @compile, @code_hlo, @jit, traced_getfield, create_result, compile export ConcreteRArray, ConcreteRNumber, @compile, @code_hlo, @jit, @trace diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 6bdbadc8f..90135e320 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -1,41 +1,23 @@ +module TracedRArrayOverrides + using Base.Broadcast using Base.Broadcast: BroadcastStyle, Broadcasted, AbstractArrayStyle, instantiate -mutable struct TracedRArray{T,N} <: RArray{T,N} - paths::Tuple - mlir_data::Union{Nothing,MLIR.IR.Value} - shape::NTuple{N,Int} - - function TracedRArray{T,N}( - paths::Tuple, mlir_data::Union{Nothing,MLIR.IR.Value}, shape - ) where {T,N} - shape = Tuple(shape) - if !isnothing(mlir_data) - @assert size(MLIR.IR.type(mlir_data)) == shape - end - return new{T,N}(paths, mlir_data, shape) - end -end - -const WrappedTracedRArray{T,N} = WrappedArray{T,N,TracedRArray,TracedRArray{T,N}} -const AnyTracedRArray{T,N} = Union{TracedRArray{T,N},WrappedTracedRArray{T,N}} -const AnyTracedRVector{T} = AnyTracedRArray{T,1} -const AnyTracedRMatrix{T} = Union{ - AnyTracedRArray{T,2},LinearAlgebra.Diagonal{T,TracedRArray{T,1}} -} -const AnyTracedRVecOrMat{T} = Union{AnyTracedRVector{T},AnyTracedRMatrix{T}} - -function TracedRArray(data::MLIR.IR.Value) - data_type = MLIR.IR.type(data) - return TracedRArray{eltype(MLIR.IR.julia_type(data_type)),ndims(data_type)}( - (), data, size(data_type) - ) -end +import ..TracedRArray +import ..TracedRNumber +import ..ReactantPrimitive +import ..WrappedTracedRArray +import ..AnyTracedRArray +using ..TracedUtils +import ..Ops +import ..MLIR +import ..ancestor +using ReactantCore: ReactantCore +import ..TracedUtils: materialize_traced_array +using GPUArraysCore: GPUArraysCore ReactantCore.is_traced(::TracedRArray) = true -new_traced_value(A::TracedRArray{T,N}) where {T,N} = TracedRArray{T,N}((), nothing, size(A)) - function Base.convert(::Type{TracedRArray{T,N}}, x::AbstractArray) where {T,N} @assert ndims(x) == N if x isa TracedRArray @@ -49,90 +31,14 @@ end TracedRArray{T,N}(x::AbstractArray) where {T,N} = convert(TracedRArray{T,N}, x) -materialize_traced_array(x::TracedRArray) = x -materialize_traced_array(x::WrappedTracedRArray) = x[axes(x)...] -function materialize_traced_array( - x::Adapt.WrappedReshapedArray{T,N,<:TracedRArray} -) where {T,N} - return Ops.reshape(materialize_traced_array(parent(x)), size(x)...) -end -function materialize_traced_array( - x::LinearAlgebra.Transpose{T,TracedRArray{T,N}} -) where {T,N} - px = parent(x) - A = ndims(px) == 1 ? 
reshape(px, :, 1) : px - return permutedims(A, (2, 1)) -end -function materialize_traced_array(x::LinearAlgebra.Adjoint{T,TracedRArray{T,N}}) where {T,N} - return conj(materialize_traced_array(transpose(parent(x)))) -end -function materialize_traced_array( - x::PermutedDimsArray{T,N,perm,iperm,<:TracedRArray{T,N}} -) where {T,N,perm,iperm} - return permutedims(parent(x), perm) -end -function materialize_traced_array(x::LinearAlgebra.Diagonal{T,TracedRArray{T,1}}) where {T} - return LinearAlgebra.diagm(parent(x)) -end - -get_mlir_data(x::TracedRArray) = x.mlir_data -get_mlir_data(x::AnyTracedRArray) = get_mlir_data(materialize_traced_array(x)) - -function set_mlir_data!(x::TracedRArray, data) - x.mlir_data = data - return x -end -function set_mlir_data!(x::Adapt.WrappedReshapedArray{T,N,<:TracedRArray}, data) where {T,N} - res_mlir_data = Ops.reshape(TracedRArray(data), size(parent(x))...).mlir_data - set_mlir_data!(parent(x), res_mlir_data) - return x -end -function set_mlir_data!(x::LinearAlgebra.Transpose{T,TracedRArray{T,N}}, data) where {T,N} - tdata = TracedRArray(data) - px = parent(x) - px.mlir_data = ( - if ndims(px) == 1 - Ops.reshape(tdata, length(tdata)) - else - Ops.transpose(tdata, [2, 1]) - end - ).mlir_data - return x -end -function set_mlir_data!(x::LinearAlgebra.Adjoint{T,TracedRArray{T,N}}, data) where {T,N} - tdata = TracedRArray(data) - px = parent(x) - transposed_data = - ndims(px) == 1 ? Ops.reshape(tdata, length(tdata)) : Ops.transpose(tdata, [2, 1]) - px.mlir_data = (T <: Real ? transposed_data : Ops.conj(transposed_data)).mlir_data - return x -end -function set_mlir_data!( - x::PermutedDimsArray{T,N,perm,iperm,TracedRArray{T,N}}, data -) where {T,N,perm,iperm} - parent(x).mlir_data = permutedims(TracedRArray(data), iperm).mlir_data - return x -end -function set_mlir_data!(x::LinearAlgebra.Diagonal{T,TracedRArray{T,1}}, data) where {T} - parent(x).mlir_data = LinearAlgebra.diag(TracedRArray(data)).mlir_data - return x -end -function set_mlir_data!(x::AnyTracedRArray, data) - setindex!(x, TracedRArray(data), axes(x)...) - return x -end - -get_ancestor_indices(::TracedRArray, indices...) = indices -function get_ancestor_indices(x::WrappedTracedRArray, indices...) - return get_ancestor_indices(parent(x), Base.reindex(parentindices(x), indices)...) -end - function Base.getindex( a::TracedRArray{T,N}, index::Vararg{Union{Int,TracedRNumber{Int}},N} ) where {T,N} GPUArraysCore.assertscalar("getindex(::TracedRArray, ::Vararg{Int, N})") - start_indices = [promote_to(TracedRNumber{Int}, i - 1).mlir_data for i in index] + start_indices = [ + TracedUtils.promote_to(TracedRNumber{Int}, i - 1).mlir_data for i in index + ] slice_sizes = [Int64(1) for _ in index] res1 = MLIR.IR.result( @@ -169,7 +75,7 @@ function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N} end start_indices = map(indices) do i - return promote_to(TracedRNumber{Int}, first(i) - 1).mlir_data + return TracedUtils.promote_to(TracedRNumber{Int}, first(i) - 1).mlir_data end slice_sizes = [Int64(length(i)) for i in indices] res = MLIR.IR.result( @@ -184,11 +90,11 @@ end # Prevent ambiguity function Base.getindex(a::WrappedTracedRArray, index::Union{Int,TracedRNumber{Int}}...) - return getindex(ancestor(a), get_ancestor_indices(a, index...)...) + return getindex(ancestor(a), TracedUtils.get_ancestor_indices(a, index...)...) end function Base.getindex(a::WrappedTracedRArray, indices...) - return getindex(ancestor(a), get_ancestor_indices(a, indices...)...) 
+ return getindex(ancestor(a), TracedUtils.get_ancestor_indices(a, indices...)...) end function Base.setindex!( @@ -199,15 +105,16 @@ function Base.setindex!( indices = map(enumerate(indices)) do (idx, i) i isa Int ? (i:i) : (i isa Colon ? (1:size(a, idx)) : i) end - v = broadcast_to_size(v, length.(indices)) - v = promote_to(TracedRArray{T,N}, v) + v = TracedUtils.broadcast_to_size(v, length.(indices)) + v = TracedUtils.promote_to(TracedRArray{T,N}, v) indices = [ - (promote_to(TracedRNumber{Int}, i isa Colon ? 1 : first(i)) - 1).mlir_data for - i in indices + ( + TracedUtils.promote_to(TracedRNumber{Int}, i isa Colon ? 1 : first(i)) - 1 + ).mlir_data for i in indices ] res = MLIR.IR.result( MLIR.Dialects.stablehlo.dynamic_update_slice( - a.mlir_data, get_mlir_data(v), indices + a.mlir_data, TracedUtils.get_mlir_data(v), indices ), 1, ) @@ -220,7 +127,7 @@ function Base.setindex!( v, indices::Vararg{Union{Base.AbstractUnitRange,Colon,Int,TracedRNumber{Int}},N}, ) where {T,N} - ancestor_indices = get_ancestor_indices(a, indices...) + ancestor_indices = TracedUtils.get_ancestor_indices(a, indices...) setindex!(ancestor(a), v, ancestor_indices...) return a end @@ -250,8 +157,9 @@ Base.conj(A::AnyTracedRArray) = A Base.conj(A::AnyTracedRArray{<:Complex}) = Ops.conj(materialize_traced_array(A)) Base.conj!(A::AnyTracedRArray) = A + function Base.conj!(A::AnyTracedRArray{<:Complex}) - set_mlir_data!(A, Ops.conj(materialize_traced_array(A)).mlir_data) + TracedUtils.set_mlir_data!(A, Ops.conj(materialize_traced_array(A)).mlir_data) return A end @@ -261,100 +169,9 @@ Base.real(A::AnyTracedRArray{<:Complex}) = Ops.real(materialize_traced_array(A)) Base.imag(A::AnyTracedRArray) = zero(A) Base.imag(A::AnyTracedRArray{<:Complex}) = Ops.imag(materialize_traced_array(A)) -promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N} = TracedRArray{T,N}(rhs) - -promote_to(::TracedRArray{T,N}, rhs) where {T,N} = promote_to(TracedRArray{T,N}, rhs) - -elem_apply(::Type{T}, x::TracedRArray{T}) where {T<:ReactantPrimitive} = x -function elem_apply( - ::Type{T}, x::TracedRArray{T2} -) where {T<:ReactantPrimitive,T2<:ReactantPrimitive} - # Special Path to prevent going down a despecialized path - return elem_apply(TypeCast{T}(), x) -end - -function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs} - if all(iszero ∘ ndims, args) - scalar_args = map(args) do arg - return promote_to(TracedRNumber{eltype(arg)}, arg) - end - return f(scalar_args...) - end - - fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = make_mlir_fn( - f, args, (), string(f) * "_broadcast_scalar", false; toscalar=true - ) - - invmap = IdDict() - for (k, v) in seen_args - invmap[v] = k - end - - keys_seen = [k for k in keys(seen_args) if k isa TracedType] - input_shapes = size.(keys_seen) - # by the time we reach here all args must have same size - @assert allequal(input_shapes) "input shapes are $(input_shapes)" - OutShape = isempty(seen_args) ? 
nothing : first(input_shapes) - @assert !isnothing(OutShape) - - in_tys2 = [mlir_type(invmap[arg]) for arg in linear_args] - - out_tys2 = [ - MLIR.IR.TensorType(OutShape, MLIR.IR.Type(eltype(arg))) for arg in linear_results - ] - - fname = get_attribute_by_name(func2, "sym_name") - fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) - - batch_inputs = MLIR.IR.Value[] - - for a in linear_args - idx, path = get_argidx(a) - if idx == 1 && fnwrap - push_val!(batch_inputs, f, path[3:end]) - else - if fnwrap - idx -= 1 - end - push_val!(batch_inputs, args[idx], path[3:end]) - end - end - - res = MLIR.Dialects.enzyme.batch( - batch_inputs; - outputs=out_tys2, - fn=fname, - batch_shape=MLIR.IR.DenseArrayAttribute([Int64(i) for i in OutShape]), - ) - - residx = 1 - - for a in linear_results - if has_residx(a) - path = get_residx(a) - set!(result, path[2:end], MLIR.IR.result(res, residx)) - residx += 1 - else - idx, path = get_argidx(a) - if idx == 1 && fnwrap - set!(f, path[3:end], MLIR.IR.result(res, residx)) - residx += 1 - else - if fnwrap - idx -= 1 - end - set!(args[idx], path[3:end], MLIR.IR.result(res, residx)) - residx += 1 - end - end - end - - seen_results = OrderedIdDict() - traced2_result = make_tracer(seen_results, result, (), TracedSetPath; tobatch=OutShape) - - func2.operation = MLIR.API.MlirOperation(C_NULL) - - return traced2_result +TracedUtils.promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N} = TracedRArray{T,N}(rhs) +function TracedUtils.promote_to(::TracedRArray{T,N}, rhs) where {T,N} + return TracedUtils.promote_to(TracedRArray{T,N}, rhs) end for (jlop, hloop, hlocomp, merge) in @@ -367,21 +184,6 @@ for (jlop, hloop, hlocomp, merge) in end end -function Enzyme.Compiler.active_reg_inner( - ::Type{TracedRArray{T,N}}, - seen::ST, - world::Union{Nothing,UInt}, - ::Val{justActive}=Val(false), - ::Val{UnionSret}=Val(false), -)::Enzyme.Compiler.ActivityState where {ST,T,N,justActive,UnionSret} - if Enzyme.Compiler.active_reg_inner(T, seen, world, Val(justActive), Val(UnionSret)) == - Enzyme.Compiler.AnyState - return Enzyme.Compiler.AnyState - else - return Enzyme.Compiler.DupState - end -end - function Base.mapreduce( @nospecialize(f), @nospecialize(op), @@ -409,40 +211,54 @@ function Base.mapreduce( init = init::T end - init = [broadcast_to_size(init, ()).mlir_data] + init = [TracedUtils.broadcast_to_size(init, ()).mlir_data] inp = [broadcast(f, A).mlir_data] - rdims = if dims == (:) - Int64[i for i in 0:(N - 1)] + rdims = Int64[] + + if dims == (:) + for i in 0:(N - 1) + push!(rdims, i) + end else - Int64[i - 1 for i in dims] + for i in dims + push!(rdims, i - 1) + end end in_tys = [ - MLIR.IR.TensorType(Int64[], eltype(MLIR.IR.type(arg))) for arg in (inp[1], init[1]) + MLIR.IR.TensorType(Int64[], eltype(MLIR.IR.type(inp[1]))), + MLIR.IR.TensorType(Int64[], eltype(MLIR.IR.type(init[1]))), ] - fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location() for arg in in_tys]) + fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location(), MLIR.IR.Location()]) args = ( - TracedRNumber{op_in_T}((), MLIR.IR.argument(fnbody, i)) for - (i, ty) in enumerate(in_tys) + TracedRNumber{op_in_T}((), MLIR.IR.argument(fnbody, 1)), + TracedRNumber{op_in_T}((), MLIR.IR.argument(fnbody, 2)), ) - res = MLIR.IR.block!(fnbody) do - tmp = broadcast_to_size(op(args...), ()).mlir_data - MLIR.Dialects.stablehlo.return_(MLIR.IR.Value[tmp]) - return tmp + resty = MLIR.IR.block!(fnbody) do + tmp = TracedUtils.broadcast_to_size(op(args...), ()) + Ops.return_(tmp) + return eltype(MLIR.IR.type(tmp.mlir_data)) end - 
toonedims = [(in(i - 1, rdims) ? 1 : size(A, i)) for i in 1:N] - outdims = [size(A, i) for i in 1:N if (i - 1) ∉ rdims] + toonedims = Int[] + outdims = Int[] + for i in 1:N + tmp = if in(i - 1, rdims) + 1 + else + sz = size(A, i) + push!(outdims, sz) + sz + end + push!(toonedims, tmp) + end - TT = [ - MLIR.IR.TensorType(outdims, eltype(MLIR.IR.type(inp0))) for - (inp0, res0) in zip(inp, (res,)) - ] + TT = MLIR.IR.Type[MLIR.IR.TensorType(outdims, resty)] body = MLIR.IR.Region() push!(body, fnbody) @@ -471,19 +287,23 @@ function Base.mapreducedim!( @nospecialize(R::TracedRArray), A::Base.AbstractArrayOrBroadcasted, ) - tmp = broadcast_to_size(Base.mapreduce(f, op, A; dims=1), (1, size(R)[2:end]...)) + tmp = TracedUtils.broadcast_to_size( + Base.mapreduce(f, op, A; dims=1), (1, size(R)[2:end]...) + ) R.mlir_data = broadcast(op, R, tmp).mlir_data return R end function Base.fill!(A::TracedRArray{T,N}, x) where {T,N} - bcast = broadcast_to_size(T(x), size(A)) + bcast = TracedUtils.broadcast_to_size(T(x), size(A)) A.mlir_data = bcast.mlir_data return A end function Base.fill!(A::TracedRArray{T,N}, x::TracedRNumber{T2}) where {T,N,T2} - bcast = broadcast_to_size(promote_to(TracedRNumber{T}, x), size(A)) + bcast = TracedUtils.broadcast_to_size( + TracedUtils.promote_to(TracedRNumber{T}, x), size(A) + ) A.mlir_data = bcast.mlir_data return A end @@ -560,77 +380,16 @@ function Base.copyto!(dest::TracedRArray{T,N}, src::TracedRArray{T,N}) where {T, return dest end -broadcast_to_size(arg::AbstractArray, rsize) = broadcast_to_size(Ops.constant(arg), rsize) - -function broadcast_to_size(arg::Base.RefValue, rsize) - # XXX: don't we want to expand here to rsize? - return arg -end - -broadcast_to_size(arg::Number, rsize) = Ops.constant(Base.fill(arg, Tuple(rsize))) - -function broadcast_to_size(arg::TracedRNumber, rsize) - length(rsize) == 0 && return arg - return broadcast_to_size_internal( - TracedRArray{eltype(arg),0}((), arg.mlir_data, ()), rsize - ) -end - -function broadcast_to_size(arg::AnyTracedRArray{T,0}, rsize) where {T} - arg = materialize_traced_array(arg) - return broadcast_to_size(TracedRNumber{T}((), arg.mlir_data), rsize) -end - -function broadcast_to_size(arg::AnyTracedRArray, rsize) - arg = materialize_traced_array(arg) - size(arg) == Tuple(rsize) && return arg - return broadcast_to_size_internal(arg, rsize) -end - -function broadcast_to_size(arg::Broadcast.Extruded, rsize) - rsize2 = (keep ? 
rsizev : 1 for (keep, rsizev) in zip(arg.keeps, rsize)) - x = broadcast_to_size(arg.x, rsize2) - size(x) == rsize && return x - return broadcast_to_size_internal(x, rsize) -end - -function broadcast_to_size_internal(x::TracedRArray, rsize) - dims = collect(Int64, 0:(length(size(x)) - 1)) - - if length(size(MLIR.IR.type(x.mlir_data))) != length(dims) - @show x - @show arg - @show rsize - @show rsize2 - @show dims - end - @assert length(size(MLIR.IR.type(x.mlir_data))) == length(dims) - mlirty = MLIR.IR.type(x.mlir_data) - - return TracedRArray{eltype(x),Int(length(rsize))}( - (), - MLIR.IR.result( - MLIR.Dialects.stablehlo.broadcast_in_dim( - x.mlir_data; - result_0=MLIR.IR.TensorType([t for t in rsize], eltype(mlirty)), - broadcast_dimensions=MLIR.IR.DenseArrayAttribute(dims), - ), - 1, - ), - collect(rsize), - ) -end - function _copyto!(dest::AnyTracedRArray, bc::Broadcasted) axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc)) isempty(dest) && return dest bc = Broadcast.preprocess(dest, bc) - args = (broadcast_to_size(Base.materialize(a), size(bc)) for a in bc.args) + args = (TracedUtils.broadcast_to_size(Base.materialize(a), size(bc)) for a in bc.args) - res = elem_apply(bc.f, args...) - set_mlir_data!(dest, res.mlir_data) + res = TracedUtils.elem_apply(bc.f, args...) + TracedUtils.set_mlir_data!(dest, res.mlir_data) return dest end @@ -642,6 +401,7 @@ dispatch_val(::Val{D}) where {D} = D ) where {T} return Base._cat_t(Val(1), T, X...) end + @inline function Base._typed_hcat( ::Type{T}, X::Base.AbstractVecOrTuple{<:TracedRArray} ) where {T} @@ -684,6 +444,12 @@ function Base._typed_hvncat( return only(As) end +function maybe_expand_dims(x::AbstractArray{T,N}, dims) where {T,N} + dims = dispatch_val(dims) + dims ≤ N && return x + return reshape(x, ntuple(i -> i ≤ N ? size(x, i) : 1, dims)) +end + function Base._cat_t(dims, ::Type{T}, X::TracedRArray...) where {T} dims = dispatch_val(dims) @assert dims isa Integer "Support for non-integer dimensions is not implemented yet." @@ -696,14 +462,14 @@ function Base._cat_t(dims, ::Type{T}, X::TracedRArray...) where {T} RT = Base.promote_eltype(T, X...) # convert to the target eltype - X = map(Base.Fix1(promote_to, TracedRArray{RT,length(shape)}), X) + X = map(Base.Fix1(TracedUtils.promote_to, TracedRArray{RT,length(shape)}), X) return TracedRArray{RT,length(shape)}( (), MLIR.IR.result( # TODO maybe we should do some conversion? MLIR.Dialects.stablehlo.concatenate( - collect(get_mlir_data.(X)); + collect(TracedUtils.get_mlir_data.(X)); result_0=MLIR.IR.TensorType(shape, MLIR.IR.Type(RT)), dimension=dims - 1, # stablehlo expects this to be zero-indexed ), @@ -713,16 +479,10 @@ function Base._cat_t(dims, ::Type{T}, X::TracedRArray...) where {T} ) end -function maybe_expand_dims(x::AbstractArray{T,N}, dims) where {T,N} - dims = dispatch_val(dims) - dims ≤ N && return x - return reshape(x, ntuple(i -> i ≤ N ? 
size(x, i) : 1, dims)) -end - for (minT, maxT) in Iterators.product((Number, TracedRNumber), (Number, TracedRNumber)) @eval function Base.clamp!(x::AnyTracedRArray, min::$(minT), max::$(maxT)) y = Ops.clamp(min, materialize_traced_array(x), max) - set_mlir_data!(x, y.mlir_data) + TracedUtils.set_mlir_data!(x, y.mlir_data) return x end end @@ -731,6 +491,7 @@ Base.all(f::Function, x::AnyTracedRArray) = mapreduce(f, &, x) Base.any(f::Function, x::AnyTracedRArray) = mapreduce(f, |, x) # outer repeat +# Overridden because we don't need to further recur into the definitions here function Base.repeat(x::AnyTracedRArray{T,N}, counts::Vararg{Int,M}) where {T,N,M} P = max(N, M) # potentially padded @@ -744,7 +505,7 @@ function Base.repeat(x::AnyTracedRArray{T,N}, counts::Vararg{Int,M}) where {T,N, broadcast_target_size = interleaved_size broadcast_target_size[2:2:(2M)] .= counts - x_broadcasted = broadcast_to_size(x_interleaved, broadcast_target_size) + x_broadcasted = TracedUtils.broadcast_to_size(x_interleaved, broadcast_target_size) # (d1, r1, d2, r2, ..., dP, rP) -> (d1*r1, d2*r2, ..., dP*rP) final_size = vec(prod(reshape(broadcast_target_size, 2, :); dims=1)) @@ -753,3 +514,5 @@ function Base.repeat(x::AnyTracedRArray{T,N}, counts::Vararg{Int,M}) where {T,N, return x_final end + +end diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index dc7a7ec2a..df664031e 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -1,36 +1,29 @@ -mutable struct TracedRNumber{T} <: RNumber{T} - paths::Tuple - mlir_data::Union{Nothing,MLIR.IR.Value} +module TracedRNumberOverrides - function TracedRNumber{T}( - paths::Tuple, mlir_data::Union{Nothing,MLIR.IR.Value} - ) where {T} - if !isnothing(mlir_data) - @assert size(MLIR.IR.type(mlir_data)) == () - end - return new{T}(paths, mlir_data) - end -end - -get_mlir_data(x::TracedRNumber) = x.mlir_data -set_mlir_data!(x::TracedRNumber, data) = (x.mlir_data = data; return x) +import ..TracedRNumber +import ..TracedRArray +import ..ReactantPrimitive +using ..TracedUtils +import ..Ops +import ..MLIR +using ReactantCore ReactantCore.is_traced(::TracedRNumber) = true -new_traced_value(::TracedRNumber{T}) where {T} = TracedRNumber{T}((), nothing) - Base.eltype(::Type{TracedRNumber{T}}) where {T} = T Base.getindex(a::TracedRNumber{T}) where {T} = a -Base.zero(::TracedRNumber{T}) where {T} = promote_to(TracedRNumber{T}, zero(T)) -Base.one(::TracedRNumber{T}) where {T} = promote_to(TracedRNumber{T}, one(T)) +Base.zero(::TracedRNumber{T}) where {T} = TracedUtils.promote_to(TracedRNumber{T}, zero(T)) +Base.one(::TracedRNumber{T}) where {T} = TracedUtils.promote_to(TracedRNumber{T}, one(T)) Base.collect(x::TracedRNumber{T}) where {T} = TracedRArray{T,0}((), x.mlir_data, ()) -Base.eps(::Type{TracedRNumber{T}}) where {T} = promote_to(TracedRNumber{T}, eps(T)) +function Base.eps(::Type{TracedRNumber{T}}) where {T} + return TracedUtils.promote_to(TracedRNumber{T}, eps(T)) +end function Base.convert(::Type{<:TracedRNumber{T}}, x::Number) where {T} - return promote_to(TracedRNumber{T}, T(x)) + return TracedUtils.promote_to(TracedRNumber{T}, T(x)) end function Base.show(io::IOty, X::TracedRNumber{T}) where {T,IOty<:Union{IO,IOContext}} @@ -57,27 +50,33 @@ function Base.promote_rule(::Type{T}, ::Type{TracedRNumber{S}}) where {T,S} end function Base.convert(::Type{TracedRNumber{T}}, x::Number) where {T} - return promote_to(TracedRNumber{T}, x) + return TracedUtils.promote_to(TracedRNumber{T}, x) end TracedRNumber{T}(x::TracedRNumber{T}) where {T} = x + function 
TracedRNumber{T}(x::Number) where {T} - return promote_to(TracedRNumber{T}, x) + return TracedUtils.promote_to(TracedRNumber{T}, x) end -function promote_to(::Type{TracedRNumber{T}}, rhs) where {T} +function TracedUtils.promote_to(::Type{TracedRNumber{T}}, rhs) where {T} if rhs isa TracedRNumber rhs isa TracedRNumber{T} && return rhs return Ops.convert(TracedRNumber{T}, rhs) end if rhs isa TracedRArray{<:Any,0} - return promote_to(TracedRNumber{T}, TracedRNumber{eltype(rhs)}((), rhs.mlir_data)) + return TracedUtils.promote_to( + TracedRNumber{T}, TracedRNumber{eltype(rhs)}((), rhs.mlir_data) + ) end - rhs isa Number && return promote_to(TracedRNumber{T}, Ops.constant(fill(T(rhs)))) - return promote_to(TracedRNumber{T}, Ops.constant(collect(rhs))) + rhs isa Number && + return TracedUtils.promote_to(TracedRNumber{T}, Ops.constant(fill(T(rhs)))) + return TracedUtils.promote_to(TracedRNumber{T}, Ops.constant(collect(rhs))) end -promote_to(::TracedRNumber{T}, rhs) where {T} = promote_to(TracedRNumber{T}, rhs) +function TracedUtils.promote_to(::TracedRNumber{T}, rhs) where {T} + return TracedUtils.promote_to(TracedRNumber{T}, rhs) +end for (jlop, hloop) in ( (:(Base.min), :minimum), @@ -98,7 +97,7 @@ end function Base.div( @nospecialize(lhs::TracedRNumber{T}), rhs, ::typeof(RoundDown) ) where {T<:Integer} - return Ops.divide(lhs, promote_to(TracedRNumber{T}, rhs)) + return Ops.divide(lhs, TracedUtils.promote_to(TracedRNumber{T}, rhs)) end for (jlop, hloop, hlocomp) in ( @@ -117,29 +116,29 @@ for (jlop, hloop, hlocomp) in ( end function $(jlop)(@nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs)) where {T} - return $(jlop)(lhs, promote_to(lhs, rhs)) + return $(jlop)(lhs, TracedUtils.promote_to(lhs, rhs)) end function $(jlop)( @nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs::Number) ) where {T} - return $(jlop)(lhs, promote_to(lhs, rhs)) + return $(jlop)(lhs, TracedUtils.promote_to(lhs, rhs)) end function $(jlop)(@nospecialize(lhs), @nospecialize(rhs::TracedRNumber{T})) where {T} - return $(jlop)(promote_to(rhs, lhs), rhs) + return $(jlop)(TracedUtils.promote_to(rhs, lhs), rhs) end function $(jlop)( @nospecialize(lhs::Number), @nospecialize(rhs::TracedRNumber{T}) ) where {T} - return $(jlop)(promote_to(rhs, lhs), rhs) + return $(jlop)(TracedUtils.promote_to(rhs, lhs), rhs) end function $(jlop)( @nospecialize(lhs::TracedRNumber{T1}), @nospecialize(rhs::TracedRNumber{T2}) ) where {T1,T2} commonTy = TracedRNumber{Base.promote_type(T1, T2)} - lhs = promote_to(commonTy, lhs) - rhs = promote_to(commonTy, rhs) + lhs = TracedUtils.promote_to(commonTy, lhs) + rhs = TracedUtils.promote_to(commonTy, rhs) return $(jlop)(lhs, rhs) end end @@ -154,7 +153,11 @@ function Base.ifelse( element-type to the common type. This is semantically different from the \ behavior of `ifelse` in Base. 
Use with caution" maxlog = 1 T = promote_type(T1, T2) - return ifelse(pred, promote_to(TracedRNumber{T}, x), promote_to(TracedRNumber{T}, y)) + return ifelse( + pred, + TracedUtils.promote_to(TracedRNumber{T}, x), + TracedUtils.promote_to(TracedRNumber{T}, y), + ) end function Base.ifelse( @@ -170,12 +173,14 @@ for (T1, T2) in zip((Bool, Integer), (Bool, Integer)) @eval begin function Base.:&(x::TracedRNumber{<:$(T1)}, y::TracedRNumber{<:$(T2)}) return Ops.and( - promote_to(TracedRNumber{$(T)}, x), promote_to(TracedRNumber{$(T)}, y) + TracedUtils.promote_to(TracedRNumber{$(T)}, x), + TracedUtils.promote_to(TracedRNumber{$(T)}, y), ) end function Base.:|(x::TracedRNumber{<:$(T1)}, y::TracedRNumber{<:$(T2)}) return Ops.or( - promote_to(TracedRNumber{$(T)}, x), promote_to(TracedRNumber{$(T)}, y) + TracedUtils.promote_to(TracedRNumber{$(T)}, x), + TracedUtils.promote_to(TracedRNumber{$(T)}, y), ) end Base.:!(x::TracedRNumber{<:$(T1)}) = Ops.not(x) @@ -220,25 +225,23 @@ for (minT, maxT) in Iterators.product((Number, TracedRNumber), (Number, TracedRN @eval Base.clamp(x::TracedRNumber, min::$(minT), max::$(maxT)) = Ops.clamp(min, x, max) end -struct TypeCast{T<:ReactantPrimitive} <: Function end - -(::TypeCast{T})(x::TracedRNumber{T2}) where {T,T2} = promote_to(TracedRNumber{T}, x) - function Base.fill(x::TracedRNumber, dims::NTuple{N,Integer}) where {N} - return Reactant.broadcast_to_size(x, dims) + return TracedUtils.broadcast_to_size(x, dims) end -Base.float(x::TracedRNumber{T}) where {T} = promote_to(TracedRNumber{float(T)}, x) +function Base.float(x::TracedRNumber{T}) where {T} + return TracedUtils.promote_to(TracedRNumber{float(T)}, x) +end # Concatenation. Numbers in Julia are handled in a much less generic fashion than arrays Base.vcat(x::TracedRNumber...) = Base.typed_vcat(Base.promote_eltypeof(x...), x...) function Base.typed_vcat(::Type{T}, x::TracedRNumber...) where {T} - return Base.typed_vcat(T, map(Base.Fix2(broadcast_to_size, (1,)), x)...) + return Base.typed_vcat(T, map(Base.Fix2(TracedUtils.broadcast_to_size, (1,)), x)...) end Base.hcat(x::TracedRNumber...) = Base.typed_hcat(Base.promote_eltypeof(x...), x...) function Base.typed_hcat(::Type{T}, x::TracedRNumber...) where {T} - return Base.typed_hcat(T, map(Base.Fix2(broadcast_to_size, (1, 1)), x)...) + return Base.typed_hcat(T, map(Base.Fix2(TracedUtils.broadcast_to_size, (1, 1)), x)...) end function Base.hvcat(rows::Tuple{Vararg{Int}}, xs::TracedRNumber...) @@ -247,7 +250,7 @@ end function Base.typed_hvcat( ::Type{T}, rows::Tuple{Vararg{Int}}, xs::TracedRNumber... ) where {T} - xs = map(Base.Fix2(broadcast_to_size, (1, 1)), xs) + xs = map(Base.Fix2(TracedUtils.broadcast_to_size, (1, 1)), xs) return Base.typed_hvcat(T, rows, xs...) end @@ -257,6 +260,8 @@ end function Base.typed_hvncat( ::Type{T}, dims::Tuple{Vararg{Int}}, row_first::Bool, xs::TracedRNumber... ) where {T} - xs = map(Base.Fix2(broadcast_to_size, (1, 1)), xs) + xs = map(Base.Fix2(TracedUtils.broadcast_to_size, (1, 1)), xs) return Base.typed_hvncat(T, dims, row_first, xs...) end + +end diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl new file mode 100644 index 000000000..d4fc11e94 --- /dev/null +++ b/src/TracedUtils.jl @@ -0,0 +1,530 @@ +# Functions within this module and Ops do not get forcibly re-compiled to be within our interpreter. +# This means that replacements, for example, for autodiff/random/kernels/etc do not get applied here when +# within compilation. However, it means these functions are a _lot_ faster to compile. 
+module TracedUtils + +using LinearAlgebra: LinearAlgebra +using Adapt: Adapt +using ..Reactant: + RArray, + RNumber, + TracedRArray, + TracedRNumber, + WrappedTracedRArray, + AnyTracedRArray, + MissingTracedValue, + OrderedIdDict +import ..Reactant +import ..Reactant.MLIR +import ..ReactantPrimitive +import ..Ops + +materialize_traced_array(x::TracedRArray) = x +materialize_traced_array(x::WrappedTracedRArray) = x[axes(x)...] +function materialize_traced_array( + x::Adapt.WrappedReshapedArray{T,N,<:TracedRArray} +) where {T,N} + return Ops.reshape(materialize_traced_array(parent(x)), size(x)...) +end +function materialize_traced_array( + x::LinearAlgebra.Transpose{T,TracedRArray{T,N}} +) where {T,N} + px = parent(x) + A = ndims(px) == 1 ? reshape(px, :, 1) : px + return permutedims(A, (2, 1)) +end +function materialize_traced_array(x::LinearAlgebra.Adjoint{T,TracedRArray{T,N}}) where {T,N} + return conj(materialize_traced_array(transpose(parent(x)))) +end +function materialize_traced_array( + x::PermutedDimsArray{T,N,perm,iperm,<:TracedRArray{T,N}} +) where {T,N,perm,iperm} + return permutedims(parent(x), perm) +end +function materialize_traced_array(x::LinearAlgebra.Diagonal{T,TracedRArray{T,1}}) where {T} + return LinearAlgebra.diagm(parent(x)) +end + +get_mlir_data(x::TracedRNumber) = x.mlir_data +set_mlir_data!(x::TracedRNumber, data) = (x.mlir_data = data; return x) + +get_mlir_data(x::TracedRArray) = x.mlir_data +get_mlir_data(x::AnyTracedRArray) = get_mlir_data(materialize_traced_array(x)) + +function set_mlir_data!(x::TracedRArray, data) + x.mlir_data = data + return x +end +function set_mlir_data!(x::Adapt.WrappedReshapedArray{T,N,<:TracedRArray}, data) where {T,N} + res_mlir_data = Ops.reshape(TracedRArray(data), size(parent(x))...).mlir_data + set_mlir_data!(parent(x), res_mlir_data) + return x +end +function set_mlir_data!(x::LinearAlgebra.Transpose{T,TracedRArray{T,N}}, data) where {T,N} + tdata = TracedRArray(data) + px = parent(x) + px.mlir_data = ( + if ndims(px) == 1 + Ops.reshape(tdata, length(tdata)) + else + Ops.transpose(tdata, [2, 1]) + end + ).mlir_data + return x +end +function set_mlir_data!(x::LinearAlgebra.Adjoint{T,TracedRArray{T,N}}, data) where {T,N} + tdata = TracedRArray(data) + px = parent(x) + transposed_data = + ndims(px) == 1 ? Ops.reshape(tdata, length(tdata)) : Ops.transpose(tdata, [2, 1]) + px.mlir_data = (T <: Real ? transposed_data : Ops.conj(transposed_data)).mlir_data + return x +end +function set_mlir_data!( + x::PermutedDimsArray{T,N,perm,iperm,TracedRArray{T,N}}, data +) where {T,N,perm,iperm} + parent(x).mlir_data = permutedims(TracedRArray(data), iperm).mlir_data + return x +end +function set_mlir_data!(x::LinearAlgebra.Diagonal{T,TracedRArray{T,1}}, data) where {T} + parent(x).mlir_data = LinearAlgebra.diag(TracedRArray(data)).mlir_data + return x +end +function set_mlir_data!(x::AnyTracedRArray, data) + setindex!(x, TracedRArray(data), axes(x)...) + return x +end + +get_ancestor_indices(::TracedRArray, indices...) = indices +function get_ancestor_indices(x::WrappedTracedRArray, indices...) + return get_ancestor_indices(parent(x), Base.reindex(parentindices(x), indices)...) +end + +function batch_ty(width, mlirty) + return MLIR.IR.TensorType([width, size(mlirty)...], eltype(mlirty)) +end + +function transpose_ty(mlirty) + return MLIR.IR.TensorType([reverse(size(mlirty))...], eltype(mlirty)) +end +function transpose_val(val) + attr = MLIR.IR.DenseArrayAttribute( + Int64[reverse(0:(length(size(MLIR.IR.type(val))) - 1))...] 
+ ) + return MLIR.IR.result(MLIR.Dialects.stablehlo.transpose(val; permutation=attr), 1) +end + +function make_mlir_fn( + f, + args, + kwargs, + name="main", + concretein=true; + toscalar=false, + return_dialect=:func, + no_args_in_result::Bool=false, + construct_function_without_args::Bool=false, + do_transpose=true, +) + if sizeof(typeof(f)) != 0 || f isa Base.BroadcastFunction + return ( + true, + make_mlir_fn( + Reactant.apply, + (f, args...), + kwargs, + name, + concretein; + toscalar, + return_dialect, + no_args_in_result, + construct_function_without_args, + do_transpose, + )[2:end]..., + ) + end + + N = length(args) + seen_args = OrderedIdDict() + traced_args = ntuple(N) do i + return Reactant.make_tracer( + seen_args, + args[i], + (:args, i), + concretein ? Reactant.ConcreteToTraced : Reactant.TracedSetPath; + toscalar, + track_numbers=construct_function_without_args ? (Number,) : (), + ) + end + + linear_args = Reactant.TracedType[] + for (k, v) in seen_args + v isa Reactant.TracedType || continue + push!(linear_args, v) + end + + in_tys = if toscalar + [MLIR.IR.TensorType((), MLIR.IR.Type(eltype(arg))) for arg in linear_args] + elseif do_transpose + [transpose_ty(Ops.mlir_type(arg)) for arg in linear_args] + else + [Ops.mlir_type(arg) for arg in linear_args] + end + + sym_visibility = nothing + if !concretein + sym_visibility = MLIR.IR.Attribute("private") + end + + mod = MLIR.IR.mmodule() + func = MLIR.IR.block!(MLIR.IR.body(mod)) do + return MLIR.Dialects.func.func_(; + sym_name=name * "_tmp", + function_type=MLIR.IR.FunctionType(in_tys, []), + body=MLIR.IR.Region(), + ) + end + + if construct_function_without_args + fnbody = MLIR.IR.Block() + else + fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location() for arg in linear_args]) + end + push!(MLIR.IR.region(func, 1), fnbody) + + @assert MLIR.IR._has_block() + + result = MLIR.IR.block!(fnbody) do + for (i, arg) in enumerate(linear_args) + if construct_function_without_args + arg.mlir_data = args[i].mlir_data + else + raw_arg = MLIR.IR.argument(fnbody, i) + row_maj_arg = do_transpose ? transpose_val(raw_arg) : raw_arg + arg.mlir_data = row_maj_arg + end + end + + # TODO fix it for kwargs + #if concretein + Reactant.call_with_reactant(f, traced_args...) + #else + # f(traced_args...) + #end + end + + seen_results = OrderedIdDict() + + traced_result = Reactant.make_tracer( + seen_results, + result, + (:result,), + concretein ? Reactant.TracedTrack : Reactant.TracedSetPath; + track_numbers=construct_function_without_args ? (Number,) : (), + ) + + # marks buffers to be donated + for i in 1:N + Reactant.make_tracer( + seen_results, + traced_args[i], + concretein ? 
(:resargs, i) : (), + Reactant.TracedTrack, + ) + end + + linear_results = Reactant.TracedType[] + + for (k, v) in seen_results + v isa Reactant.TracedType || continue + (no_args_in_result && length(v.paths) > 0 && v.paths[1][1] == :args) && continue + push!(linear_results, v) + end + + out_tys = [transpose_ty(Ops.mlir_type(arg)) for arg in linear_results] + + ret = MLIR.IR.block!(fnbody) do + vals = MLIR.IR.Value[] + for res in linear_results + col_maj = if res isa MissingTracedValue + broadcast_to_size(false, ()).mlir_data + elseif construct_function_without_args || !do_transpose + res.mlir_data + elseif do_transpose + transpose_val(res.mlir_data) + end + push!(vals, col_maj) + end + !no_args_in_result && @assert length(vals) == length(linear_results) + + dialect = getfield(MLIR.Dialects, return_dialect) + return dialect.return_(vals) + end + + name2 = name + + tab = MLIR.IR.SymbolTable(MLIR.IR.Operation(mod)) + for i in 0:10000 + name2 = if i == 0 + name + else + name * string(i) + end + if MLIR.IR.mlirIsNull(MLIR.API.mlirSymbolTableLookup(tab, name2)) + break + end + end + + func2 = MLIR.IR.block!(MLIR.IR.body(mod)) do + return MLIR.Dialects.func.func_(; + sym_name=name2, + function_type=MLIR.IR.FunctionType(in_tys, out_tys), + body=MLIR.IR.Region(), + sym_visibility, + ) + end + MLIR.API.mlirRegionTakeBody(MLIR.IR.region(func2, 1), MLIR.IR.region(func, 1)) + + MLIR.API.mlirOperationDestroy(func.operation) + func.operation = MLIR.API.MlirOperation(C_NULL) + return ( + false, + func2, + traced_result, + result, + seen_args, + ret, + linear_args, + in_tys, + linear_results, + ) +end + +elem_apply(::Type{T}, x::TracedRArray{T}) where {T<:ReactantPrimitive} = x + +struct TypeCast{T<:ReactantPrimitive} <: Function end + +function (::TypeCast{T})(x::TracedRNumber{T2}) where {T,T2} + return TracedUtils.promote_to(TracedRNumber{T}, x) +end + +function elem_apply( + ::Type{T}, x::TracedRArray{T2} +) where {T<:ReactantPrimitive,T2<:ReactantPrimitive} + # Special Path to prevent going down a despecialized path + return elem_apply(TypeCast{T}(), x) +end + +function promote_to end + +function get_attribute_by_name(operation, name) + return MLIR.IR.Attribute(MLIR.API.mlirOperationGetAttributeByName(operation, name)) +end + +function push_val!(ad_inputs, x, path) + for p in path + x = traced_getfield(x, p) + end + x = x.mlir_data + return push!(ad_inputs, x) +end + +function get_argidx(x) + for path in x.paths + if length(path) == 0 + continue + end + if path[1] == :args + return path[2]::Int, path + end + end + throw(AssertionError("No path found for $x")) +end + +function set!(x, path, tostore; emptypath=false) + for p in path + x = traced_getfield(x, p) + end + + x.mlir_data = tostore + + if emptypath + x.paths = () + end +end + +function get_residx(x) + for path in x.paths + if length(path) == 0 + continue + end + if path[1] == :result + return path + end + end + throw(AssertionError("No path found $x")) +end + +function has_residx(x) + for path in x.paths + if length(path) == 0 + continue + end + if path[1] == :result + return true + end + end + return false +end + +function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs} + if all(iszero ∘ ndims, args) + scalar_args = map(args) do arg + return promote_to(TracedRNumber{eltype(arg)}, arg) + end + return f(scalar_args...) 
+ end + + fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = make_mlir_fn( + f, args, (), string(f) * "_broadcast_scalar", false; toscalar=true + ) + + invmap = IdDict() + for (k, v) in seen_args + invmap[v] = k + end + + keys_seen = [k for k in keys(seen_args) if k isa Reactant.TracedType] + input_shapes = size.(keys_seen) + # by the time we reach here all args must have same size + @assert allequal(input_shapes) "input shapes are $(input_shapes)" + OutShape = isempty(seen_args) ? nothing : first(input_shapes) + @assert !isnothing(OutShape) + + in_tys2 = [Ops.mlir_type(invmap[arg]) for arg in linear_args] + + out_tys2 = [ + MLIR.IR.TensorType(OutShape, MLIR.IR.Type(eltype(arg))) for arg in linear_results + ] + + fname = get_attribute_by_name(func2, "sym_name") + fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) + + batch_inputs = MLIR.IR.Value[] + + for a in linear_args + idx, path = TracedUtils.get_argidx(a) + if idx == 1 && fnwrap + push_val!(batch_inputs, f, path[3:end]) + else + if fnwrap + idx -= 1 + end + push_val!(batch_inputs, args[idx], path[3:end]) + end + end + + res = MLIR.Dialects.enzyme.batch( + batch_inputs; + outputs=out_tys2, + fn=fname, + batch_shape=MLIR.IR.DenseArrayAttribute([Int64(i) for i in OutShape]), + ) + + residx = 1 + + for a in linear_results + if TracedUtils.has_residx(a) + path = TracedUtils.get_residx(a) + TracedUtils.set!(result, path[2:end], MLIR.IR.result(res, residx)) + residx += 1 + else + idx, path = TracedUtils.get_argidx(a) + if idx == 1 && fnwrap + TracedUtils.set!(f, path[3:end], MLIR.IR.result(res, residx)) + residx += 1 + else + if fnwrap + idx -= 1 + end + TracedUtils.set!(args[idx], path[3:end], MLIR.IR.result(res, residx)) + residx += 1 + end + end + end + + seen_results = OrderedIdDict() + traced2_result = Reactant.make_tracer( + seen_results, result, (), Reactant.TracedSetPath; tobatch=OutShape + ) + + func2.operation = MLIR.API.MlirOperation(C_NULL) + + return traced2_result +end + +new_traced_value(A::TracedRArray{T,N}) where {T,N} = TracedRArray{T,N}((), nothing, size(A)) +new_traced_value(::TracedRNumber{T}) where {T} = TracedRNumber{T}((), nothing) + +broadcast_to_size(arg::AbstractArray, rsize) = broadcast_to_size(Ops.constant(arg), rsize) + +function broadcast_to_size(arg::Base.RefValue, rsize) + # XXX: don't we want to expand here to rsize? + return arg +end + +broadcast_to_size(arg::Number, rsize) = Ops.constant(Base.fill(arg, Tuple(rsize))) + +function broadcast_to_size(arg::TracedRNumber, rsize) + length(rsize) == 0 && return arg + return broadcast_to_size_internal( + TracedRArray{eltype(arg),0}((), arg.mlir_data, ()), rsize + ) +end + +function broadcast_to_size(arg::AnyTracedRArray{T,0}, rsize) where {T} + arg = materialize_traced_array(arg) + return broadcast_to_size(TracedRNumber{T}((), arg.mlir_data), rsize) +end + +function broadcast_to_size(arg::AnyTracedRArray, rsize) + arg = materialize_traced_array(arg) + size(arg) == Tuple(rsize) && return arg + return broadcast_to_size_internal(arg, rsize) +end + +function broadcast_to_size(arg::Broadcast.Extruded, rsize) + rsize2 = (keep ? 
rsizev : 1 for (keep, rsizev) in zip(arg.keeps, rsize)) + x = broadcast_to_size(arg.x, rsize2) + size(x) == rsize && return x + return broadcast_to_size_internal(x, rsize) +end + +@noinline function broadcast_to_size_internal(x::TracedRArray, rsize) + dims = collect(Int64, 0:(length(size(x)) - 1)) + + if length(size(MLIR.IR.type(x.mlir_data))) != length(dims) + @show x + @show arg + @show rsize + @show rsize2 + @show dims + end + @assert length(size(MLIR.IR.type(x.mlir_data))) == length(dims) + mlirty = MLIR.IR.type(x.mlir_data) + + return TracedRArray{eltype(x),Int(length(rsize))}( + (), + MLIR.IR.result( + MLIR.Dialects.stablehlo.broadcast_in_dim( + x.mlir_data; + result_0=MLIR.IR.TensorType([t for t in rsize], eltype(mlirty)), + broadcast_dimensions=MLIR.IR.DenseArrayAttribute(dims), + ), + 1, + ), + collect(rsize), + ) +end + +end diff --git a/src/Tracing.jl b/src/Tracing.jl index 4ea8172aa..62bb71f69 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -498,14 +498,18 @@ function make_tracer( return ConcreteRNumber(prev) else if mode == TracedTrack - res = TracedRNumber{RT}((path,), broadcast_to_size(prev, ()).mlir_data) + res = TracedRNumber{RT}( + (path,), TracedUtils.broadcast_to_size(prev, ()).mlir_data + ) if !haskey(seen, prev) return seen[prev] = res end return res elseif mode == TracedSetPath haskey(seen, prev) && return seen[prev] - res = TracedRNumber{RT}((path,), broadcast_to_size(prev, ()).mlir_data) + res = TracedRNumber{RT}( + (path,), TracedUtils.broadcast_to_size(prev, ()).mlir_data + ) seen[prev] = res return res elseif mode == TracedToConcrete diff --git a/src/linear_algebra.jl b/src/linear_algebra.jl index 745195217..c011f8aec 100644 --- a/src/linear_algebra.jl +++ b/src/linear_algebra.jl @@ -1,3 +1,19 @@ +module TracedLinearAlgebra + +using ..Reactant +import ..TracedRArray +import ..TracedRNumber +import ..AnyTracedRArray +import ..AnyTracedRMatrix +import ..AnyTracedRVector + +import ..TracedUtils +using ..TracedUtils: get_mlir_data, materialize_traced_array, set_mlir_data! + +import ..Ops +import ..MLIR +using LinearAlgebra + function LinearAlgebra.mul!( @nospecialize(C::TracedRArray{T,1}), @nospecialize(A::AnyTracedRMatrix), @@ -48,10 +64,10 @@ function LinearAlgebra.mul!( ) res = if iszero(β) - isone(α) ? tmp : Ops.multiply(tmp, broadcast_to_size(T(α), size(C))) + isone(α) ? 
tmp : Ops.multiply(tmp, TracedUtils.broadcast_to_size(T(α), size(C))) else - α_res = Ops.multiply(tmp, broadcast_to_size(T(α), size(C))) - β_C = Ops.multiply(C, broadcast_to_size(T(β), size(C))) + α_res = Ops.multiply(tmp, TracedUtils.broadcast_to_size(T(α), size(C))) + β_C = Ops.multiply(C, TracedUtils.broadcast_to_size(T(β), size(C))) Ops.add(α_res, β_C) end set_mlir_data!(C, get_mlir_data(res)) @@ -61,7 +77,8 @@ end function LinearAlgebra.triu!(@nospecialize(X::TracedRArray{T,2}), k::Integer) where {T} iota_1 = Ops.iota(Int64, [size(X)...]; iota_dimension=1) iota_2 = Ops.subtract( - Ops.iota(Int64, [size(X)...]; iota_dimension=2), broadcast_to_size(k, size(X)) + Ops.iota(Int64, [size(X)...]; iota_dimension=2), + TracedUtils.broadcast_to_size(k, size(X)), ) idxs = Ops.compare(iota_1, iota_2; comparison_direction="LE") X.mlir_data = Ops.select(idxs, X, zero(X)).mlir_data @@ -71,7 +88,8 @@ end function LinearAlgebra.tril!(@nospecialize(X::TracedRArray{T,2}), k::Integer) where {T} iota_1 = Ops.iota(Int64, [size(X)...]; iota_dimension=1) iota_2 = Ops.subtract( - Ops.iota(Int64, [size(X)...]; iota_dimension=2), broadcast_to_size(k, size(X)) + Ops.iota(Int64, [size(X)...]; iota_dimension=2), + TracedUtils.broadcast_to_size(k, size(X)), ) idxs = Ops.compare(iota_1, iota_2; comparison_direction="GE") X.mlir_data = Ops.select(idxs, X, zero(X)).mlir_data @@ -99,9 +117,9 @@ function LinearAlgebra.diag(x::AnyTracedRArray{T,2}, k::Integer=0) where {T} # terminate called after throwing an instance of 'xla::XlaRuntimeError' # what(): UNKNOWN: :0: error: 'tensor.empty' op unsupported op for export to XLA # :0: note: see current operation: %0 = "tensor.empty"() : () -> tensor<0xf64> - length(indices) ≤ 0 && return promote_to(TracedRArray{T,1}, T[]) + length(indices) ≤ 0 && return TracedUtils.promote_to(TracedRArray{T,1}, T[]) - idxs = get_mlir_data(Reactant.promote_to(TracedRArray{Int,2}, indices)) + idxs = get_mlir_data(TracedUtils.promote_to(TracedRArray{Int,2}, indices)) #! format: off dimension_numbers = MLIR.API.stablehloGatherDimensionNumbersGet( @@ -115,7 +133,9 @@ function LinearAlgebra.diag(x::AnyTracedRArray{T,2}, k::Integer=0) where {T} ) #! format: on - slice_sizes = get_mlir_data(Reactant.promote_to(TracedRArray{Int,1}, [1, 1])) + slice_sizes = get_mlir_data( + Reactant.TracedUtils.promote_to(TracedRArray{Int,1}, [1, 1]) + ) res = MLIR.IR.result( MLIR.Dialects.stablehlo.dynamic_gather( get_mlir_data(y), idxs, slice_sizes; dimension_numbers @@ -139,6 +159,10 @@ function LinearAlgebra.diagm(m::Integer, n::Integer, v::AnyTracedRArray{T,1}) wh mat = (v .+ zero(v)') .* diag_indicator return Ops.pad( - mat, promote_to(TracedRNumber{T}, 0); high=[m - length(v), n - length(v)] + mat, + TracedUtils.promote_to(TracedRNumber{T}, 0); + high=[m - length(v), n - length(v)], ) end + +end diff --git a/src/utils.jl b/src/utils.jl index b37e00fd1..b65077c03 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,278 +1,616 @@ -function mlir_type(x::RArray{T,N}) where {T,N} - return MLIR.IR.TensorType(size(x), MLIR.IR.Type(T)) -end -mlir_type(::RNumber{T}) where {T} = MLIR.IR.TensorType((), MLIR.IR.Type(T)) +function apply(f, args...; kwargs...) + return f(args...; kwargs...) 
+end -mlir_type(::MissingTracedValue) = MLIR.IR.TensorType((), MLIR.IR.Type(Bool)) +function call_with_reactant end -function mlir_type(::Type{<:RArray{T,N}}, shape) where {T,N} - @assert length(shape) == N - return MLIR.IR.TensorType(shape, MLIR.IR.Type(T)) +function maybe_argextype(@nospecialize(x), src) + return try + Core.Compiler.argextype(x, src) + catch err + !(err isa Core.Compiler.InvalidIRError) && rethrow() + nothing + end end -function mlir_type(::Type{<:RNumber{T}}) where {T} - return MLIR.IR.TensorType((), MLIR.IR.Type(T)) -end +""" + Reactant.REDUB_ARGUMENTS_NAME -function mlir_type(::Type{<:MissingTracedValue}) - return MLIR.IR.TensorType((), MLIR.IR.Type(Bool)) -end +The variable name bound to `call_with_reactant`'s tuple of arguments in its +`@generated` method definition. -function batch_ty(width, mlirty) - return MLIR.IR.TensorType([width, size(mlirty)...], eltype(mlirty)) -end +This binding can be used to manually reference/destructure `call_with_reactants` arguments -function transpose_ty(mlirty) - return MLIR.IR.TensorType([reverse(size(mlirty))...], eltype(mlirty)) +This is required because user arguments could have a name which clashes with whatever name we choose for +our argument. Thus we gensym to create it. + +This originates from https://github.com/JuliaLabs/Cassette.jl/blob/c29b237c1ec0deda3a1037ec519eebe216952bfe/src/overdub.jl#L154 +""" +const REDUB_ARGUMENTS_NAME = gensym("redub_arguments") + +function throw_method_error(argtys) + throw(MethodError(argtys[1], argtys[2:end])) end -function transpose_val(val) - attr = MLIR.IR.DenseArrayAttribute( - Int64[reverse(0:(length(size(MLIR.IR.type(val))) - 1))...] + +@inline function lookup_world( + @nospecialize(sig::Type), + world::UInt, + mt::Union{Nothing,Core.MethodTable}, + min_world::Ref{UInt}, + max_world::Ref{UInt}, +) + res = ccall( + :jl_gf_invoke_lookup_worlds, + Any, + (Any, Any, Csize_t, Ref{Csize_t}, Ref{Csize_t}), + sig, + mt, + world, + min_world, + max_world, ) - return MLIR.IR.result(MLIR.Dialects.stablehlo.transpose(val; permutation=attr), 1) + return res end -function apply(f, args...; kwargs...) - return f(args...; kwargs...) +@inline function lookup_world( + @nospecialize(sig::Type), + world::UInt, + mt::Core.Compiler.InternalMethodTable, + min_world::Ref{UInt}, + max_world::Ref{UInt}, +) + res = lookup_world(sig, mt.world, nothing, min_world, max_world) + return res end -function make_mlir_fn( - f, - args, - kwargs, - name="main", - concretein=true; - toscalar=false, - return_dialect=:func, - no_args_in_result::Bool=false, - construct_function_without_args::Bool=false, - do_transpose=true, +@inline function lookup_world( + @nospecialize(sig::Type), + world::UInt, + mt::Core.Compiler.OverlayMethodTable, + min_world::Ref{UInt}, + max_world::Ref{UInt}, ) - if sizeof(typeof(f)) != 0 || f isa BroadcastFunction - return ( - true, - make_mlir_fn( - apply, - (f, args...), - kwargs, - name, - concretein; - toscalar, - return_dialect, - no_args_in_result, - construct_function_without_args, - do_transpose, - )[2:end]..., - ) + res = lookup_world(sig, mt.world, mt.mt, min_world, max_world) + if res !== nothing + return res + else + return lookup_world(sig, mt.world, nothing, min_world, max_world) end +end - N = length(args) - seen_args = OrderedIdDict() - traced_args = ntuple(N) do i - return make_tracer( - seen_args, - args[i], - (:args, i), - concretein ? ConcreteToTraced : TracedSetPath; - toscalar, - track_numbers=construct_function_without_args ? 
(Number,) : (), - ) +function has_ancestor(query::Module, target::Module) + query == target && return true + while true + next = parentmodule(query) + next == target && return true + next == query && return false + query = next end +end - linear_args = TracedType[] - for (k, v) in seen_args - v isa TracedType || continue - push!(linear_args, v) +function should_rewrite_ft(@nospecialize(ft)) + # Don't rewrite builtin or intrinsics + if ft <: Core.IntrinsicFunction || ft <: Core.Builtin + return false + end + if ft <: Core.Function + mod = ft.name.module + # Don't rewrite primitive ops, tracing utilities, or any MLIR-based functions + if has_ancestor(mod, Reactant.Ops) || + has_ancestor(mod, Reactant.TracedUtils) || + has_ancestor(mod, Reactant.MLIR) + return false + end + end + # Don't rewrite Val + if ft === Type{Base.Val} + return false + end + # Don't rewrite exception constructors + if ft <: Type{<:Core.Exception} + return false end - in_tys = if toscalar - [MLIR.IR.TensorType((), MLIR.IR.Type(eltype(arg))) for arg in linear_args] - elseif do_transpose - [transpose_ty(mlir_type(arg)) for arg in linear_args] - else - [mlir_type(arg) for arg in linear_args] + # Avoid the 1.10 stackoverflow + if ft <: typeof(Base.typed_hvcat) + return false + end + if ft <: typeof(Base.hvcat) + return false end - sym_visibility = nothing - if !concretein - sym_visibility = MLIR.IR.Attribute("private") + # Don't rewrite traced constructors + if ft <: Type{<:TracedRArray} || + ft <: Type{<:TracedRNumber} || + ft === Type{MLIR.IR.Location} || + ft === Type{MLIR.IR.Block} + return false end - mod = MLIR.IR.mmodule() - func = MLIR.IR.block!(MLIR.IR.body(mod)) do - return MLIR.Dialects.func.func_(; - sym_name=name * "_tmp", - function_type=MLIR.IR.FunctionType(in_tys, []), - body=MLIR.IR.Region(), - ) + # Perf optimizations + if ft <: typeof(Core.Compiler.return_type) + return false end - if construct_function_without_args - fnbody = MLIR.IR.Block() - else - fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location() for arg in linear_args]) + # Perf optimizations + if ft <: typeof(Base.typemax) || + ft <: typeof(Base.typemin) || + ft <: typeof(Base.getproperty) || + ft <: typeof(Base.vect) || + ft <: typeof(Base.eltype) + return false end - push!(MLIR.IR.region(func, 1), fnbody) - @assert MLIR.IR._has_block() + # Default assume all functions need to be reactant-ified + return true +end - result = MLIR.IR.block!(fnbody) do - for (i, arg) in enumerate(linear_args) - if construct_function_without_args - arg.mlir_data = args[i].mlir_data - else - raw_arg = MLIR.IR.argument(fnbody, i) - row_maj_arg = do_transpose ? transpose_val(raw_arg) : raw_arg - arg.mlir_data = row_maj_arg - end +# Avoid recursively interpreting into methods we define explicitly +# as overloads, which we assume should handle the entirety of the +# translation (and if not they can use call_in_reactant). +function is_reactant_method(mi::Core.MethodInstance) + meth = mi.def + if !isdefined(meth, :external_mt) + return false + end + mt = meth.external_mt + return mt === REACTANT_METHOD_TABLE +end + +function rewrite_inst(inst, ir, interp) + if Meta.isexpr(inst, :call) + # Even if type unstable we do not want (or need) to replace intrinsic + # calls or builtins with our version. + ft = Core.Compiler.widenconst(maybe_argextype(inst.args[1], ir)) + if ft == typeof(Core.kwcall) + ft = Core.Compiler.widenconst(maybe_argextype(inst.args[3], ir)) end + if should_rewrite_ft(ft) + rep = Expr(:call, call_with_reactant, inst.args...) 
+ return true, rep + end + end + if Meta.isexpr(inst, :invoke) + omi = inst.args[1]::Core.MethodInstance + sig = omi.specTypes + ft = sig.parameters[1] + if ft == typeof(Core.kwcall) + ft = sig.parameters[3] + end + if should_rewrite_ft(ft) && !is_reactant_method(omi) + method = omi.def::Core.Method - interp = ReactantInterpreter() + min_world = Ref{UInt}(typemin(UInt)) + max_world = Ref{UInt}(typemax(UInt)) - # TODO replace with `Base.invoke_within` if julia#52964 lands - # TODO fix it for kwargs - ircoderes = Base.code_ircode(f, map(typeof, traced_args); interp) + if !method.isva || !Base.isvarargtype(sig.parameters[end]) + sig2 = Tuple{typeof(call_with_reactant),sig.parameters...} + else + vartup = inst.args[end] + ns = Type[] + eT = sig.parameters[end].T + for i in 1:(length(inst.args) - 1 - (length(sig.parameters) - 1)) + push!(ns, eT) + end + sig2 = Tuple{ + typeof(call_with_reactant),sig.parameters[1:(end - 1)]...,ns... + } + end - if length(ircoderes) != 1 - throw( - AssertionError( - "Could not find unique ircode for $f $traced_args, found $ircoderes" - ), + lookup_result = lookup_world( + sig2, interp.world, Core.Compiler.method_table(interp), min_world, max_world ) - end - ir, ty = ircoderes[1] - oc = Core.OpaqueClosure(ir) - if f === Reactant.apply - oc(traced_args[1], (traced_args[2:end]...,)) - else - if (length(traced_args) + 1 != length(ir.argtypes)) || ( - length(traced_args) > 0 && - length(ir.argtypes) > 0 && - !(last(ir.argtypes) isa Core.Const) && - last(ir.argtypes) != typeof(traced_args[end]) + match = lookup_result::Core.MethodMatch + # look up the method and code instance + mi = ccall( + :jl_specializations_get_linfo, + Ref{Core.MethodInstance}, + (Any, Any, Any), + match.method, + match.spec_types, + match.sparams, ) - @assert ir.argtypes[end] <: Tuple - oc( - traced_args[1:(length(ir.argtypes) - 2)]..., - (traced_args[(length(ir.argtypes) - 1):end]...,), - ) - else - oc(traced_args...) - end + n_method_args = method.nargs + rep = Expr(:invoke, mi, call_with_reactant, inst.args[2:end]...) + return true, rep end end + return false, inst +end + +const oc_captures = Dict{Tuple{Type,Type,Core.CodeInfo,Int,Bool,Any},Core.OpaqueClosure}() + +# Caching is both good to reducing compile times and necessary to work around julia bugs +# in OpaqueClosure's: https://github.com/JuliaLang/julia/issues/56833 +function make_oc( + sig::Type, rt::Type, src::Core.CodeInfo, nargs::Int, isva::Bool, f::Any +)::Core.OpaqueClosure + key = (sig, rt, src, nargs, isva, f) + if haskey(oc_captures, key) + return oc_captures[key] + else + ores = ccall( + :jl_new_opaque_closure_from_code_info, + Any, + (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any, Cint), + sig, + rt, + rt, + @__MODULE__, + src, + 0, + nothing, + nargs, + isva, + f, + true, + )::Core.OpaqueClosure + oc_captures[key] = ores + return ores + end +end + +function safe_print(name, x) + return ccall(:jl_, Cvoid, (Any,), name * " " * string(x)) +end + +const DEBUG_INTERP = Ref(false) + +# Generator function which ensures that all calls to the function are executed within the ReactantInterpreter +# In particular this entails two pieces: +# 1) We enforce the use of the ReactantInterpreter method table when generating the original methodinstance +# 2) Post type inference (using of course the reactant interpreter), all type unstable call functions are +# replaced with calls to `call_with_reactant`. This allows us to circumvent long standing issues in Julia +# using a custom interpreter in type unstable code. 
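+# For instance (illustrative sketch): a traced call `call_with_reactant(f, x, y)` reaches this
+# generator, which infers `f(x, y)` under the ReactantInterpreter, applies the statement rewrite
+# below to any remaining dynamic calls, and finally runs the result through an opaque closure
+# built from the optimized code.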
+# `redub_arguments` is `(typeof(original_function), map(typeof, original_args_tuple)...)` +function call_with_reactant_generator( + world::UInt, source::LineNumberNode, self, @nospecialize(redub_arguments) +) + @nospecialize + args = redub_arguments + if DEBUG_INTERP[] + safe_print("args", args) + end + + stub = Core.GeneratedFunctionStub( + identity, Core.svec(:call_with_reactant, REDUB_ARGUMENTS_NAME), Core.svec() + ) + + # look up the method match + builtin_error = :(throw( + AssertionError("Unsupported call_with_reactant of builtin $redub_arguments") + )) + + if args[1] <: Core.Builtin + return stub(world, source, builtin_error) + end + method_error = :(throw( + MethodError($REDUB_ARGUMENTS_NAME[1], $REDUB_ARGUMENTS_NAME[2:end], $world) + )) - seen_results = OrderedIdDict() + interp = ReactantInterpreter(; world) - traced_result = make_tracer( - seen_results, - result, - (:result,), - concretein ? TracedTrack : TracedSetPath; - track_numbers=construct_function_without_args ? (Number,) : (), + sig = Tuple{args...} + + min_world = Ref{UInt}(typemin(UInt)) + max_world = Ref{UInt}(typemax(UInt)) + + lookup_result = lookup_world( + sig, world, Core.Compiler.method_table(interp), min_world, max_world ) - # marks buffers to be donated - for i in 1:N - make_tracer( - seen_results, traced_args[i], concretein ? (:resargs, i) : (), TracedTrack + overdubbed_code = Any[] + overdubbed_codelocs = Int32[] + + # No method could be found (including in our method table), bail with an error + if lookup_result == nothing + return stub(world, source, method_error) + tmp_min_world = Ref{UInt}(typemin(UInt)) + tmp_max_world = Ref{UInt}(typemax(UInt)) + match = ccall( + :jl_gf_invoke_lookup_worlds, + Any, + (Any, Any, Csize_t, Ref{Csize_t}, Ref{Csize_t}), + Tuple{typeof(throw_method_error),sig}, + nothing, + world, + tmp_min_world, + tmp_max_world, + ) #=mt=# + @assert match !== nothing + + # look up the method and code instance + mi = ccall( + :jl_specializations_get_linfo, + Ref{Core.MethodInstance}, + (Any, Any, Any), + match.method, + match.spec_types, + match.sparams, ) + + ci = Core.Compiler.retrieve_code_info(mi, world)::Core.Compiler.CodeInfo + + src = copy(ci) + src.slotnames = Any[:call_with_reactant, REDUB_ARGUMENTS_NAME] + + src.edges = Any[ + ccall(:jl_method_table_for, Any, (Any,), sig)::Core.MethodTable, sig + ] + src.min_world = min_world[] + src.max_world = max_world[] + + push!(overdubbed_code, :($(Base.getindex)($(Core.Argument(2)), 1))) + push!(overdubbed_codelocs, 0) + + expr_fn = Core.SSAValue(length(overdubbed_code)) + + push!(overdubbed_code, :($(Base.lastindex)($(Core.Argument(2))))) + push!(overdubbed_codelocs, 0) + + expr_lastindex = Core.SSAValue(length(overdubbed_code)) + + push!(overdubbed_code, :(2:($expr_lastindex))) + push!(overdubbed_codelocs, 0) + + expr_slice = Core.SSAValue(length(overdubbed_code)) + + push!(overdubbed_code, :($(Base.getindex)($(Core.Argument(2)), $expr_slice))) + push!(overdubbed_codelocs, 0) + + expr_args = Core.SSAValue(length(overdubbed_code)) + + push!(overdubbed_code, :($(Base.MethodError)($expr_fn, $expr_args, $world))) + push!(overdubbed_codelocs, 0) + + expr_method = Core.SSAValue(length(overdubbed_code)) + + push!(overdubbed_code, :($(Base.throw)($expr_method))) + push!(overdubbed_codelocs, 0) + + push!(overdubbed_code, Core.ReturnNode(Core.SSAValue(length(overdubbed_code)))) + push!(overdubbed_codelocs, 0) + + src.code = overdubbed_code + src.codelocs = overdubbed_codelocs + src.ssavaluetypes = length(overdubbed_code) + src.ssaflags = [0x00 for 
_ in 1:length(overdubbed_code)] # XXX we need to copy flags that are set for the original code + + return src end - linear_results = TracedType[] + match = lookup_result::Core.MethodMatch + # look up the method and code instance + mi = ccall( + :jl_specializations_get_linfo, + Ref{Core.MethodInstance}, + (Any, Any, Any), + match.method, + match.spec_types, + match.sparams, + ) - for (k, v) in seen_results - v isa TracedType || continue - (no_args_in_result && length(v.paths) > 0 && v.paths[1][1] == :args) && continue - push!(linear_results, v) + result = Core.Compiler.InferenceResult(mi, Core.Compiler.typeinf_lattice(interp)) + frame = Core.Compiler.InferenceState(result, VERSION < v"1.11-" ? :local : :no, interp) #=cache_mode=# + @assert frame !== nothing + Core.Compiler.typeinf(interp, frame) + @static if VERSION >= v"1.11" + # `typeinf` doesn't update the cfg. We need to do it manually. + # frame.cfg = Core.Compiler.compute_basic_blocks(frame.src.code) end + @assert Core.Compiler.is_inferred(frame) - out_tys = [transpose_ty(mlir_type(arg)) for arg in linear_results] - - ret = MLIR.IR.block!(fnbody) do - vals = MLIR.IR.Value[] - for res in linear_results - col_maj = if res isa MissingTracedValue - broadcast_to_size(false, ()).mlir_data - elseif construct_function_without_args || !do_transpose - res.mlir_data - elseif do_transpose - transpose_val(res.mlir_data) - end - push!(vals, col_maj) - end - !no_args_in_result && @assert length(vals) == length(linear_results) + method = match.method - dialect = getfield(MLIR.Dialects, return_dialect) - return dialect.return_(vals) + # The original julia code (on 1.11+) has the potential constprop, for now + # we assume this outermost function does not constprop, for ease. + #if Core.Compiler.result_is_constabi(interp, frame.result) + # rt = frame.result.result::Core.Compiler.Const + # src = Core.Compiler.codeinfo_for_const(interp, frame.linfo, rt.val) + #else + # + opt = Core.Compiler.OptimizationState(frame, interp) + + if DEBUG_INTERP[] + safe_print("opt.src", opt.src) end - name2 = name + caller = frame.result + @static if VERSION < v"1.11-" + ir = Core.Compiler.run_passes(opt.src, opt, caller) + else + ir = Core.Compiler.run_passes_ipo_safe(opt.src, opt, caller) + @static if VERSION < v"1.12-" + else + Core.Compiler.ipo_dataflow_analysis!(interp, ir, caller) + end + end - tab = MLIR.IR.SymbolTable(MLIR.IR.Operation(mod)) - for i in 0:10000 - name2 = if i == 0 - name + if DEBUG_INTERP[] + safe_print("ir1", ir) + end + + # Rewrite type unstable calls to recurse into call_with_reactant to ensure + # they continue to use our interpreter. Reset the derived return type + # to Any if our interpreter would change the return type of any result. + # Also rewrite invoke (type stable call) to be :call, since otherwise apparently + # screws up type inference after this (TODO this should be fixed). 
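+    # For example (illustrative): a dynamically-dispatched statement such as
+    #     %5 = (%2)(%3, %4)
+    # is rewritten below to
+    #     %5 = call_with_reactant(%2, %3, %4)
+    # and its recorded return type is widened to `Any`, since re-inference under our interpreter
+    # may change it; statically-resolved `:invoke`s are redirected similarly.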
+ any_changed = false + for (i, inst) in enumerate(ir.stmts) + @static if VERSION < v"1.11" + changed, next = rewrite_inst(inst[:inst], ir, interp) + Core.Compiler.setindex!(ir.stmts[i], next, :inst) else - name * string(i) + changed, next = rewrite_inst(inst[:stmt], ir, interp) + Core.Compiler.setindex!(ir.stmts[i], next, :stmt) end - if MLIR.IR.mlirIsNull(MLIR.API.mlirSymbolTableLookup(tab, name2)) - break + if changed + any_changed = true + Core.Compiler.setindex!(ir.stmts[i], Any, :type) end end - func2 = MLIR.IR.block!(MLIR.IR.body(mod)) do - return MLIR.Dialects.func.func_(; - sym_name=name2, - function_type=MLIR.IR.FunctionType(in_tys, out_tys), - body=MLIR.IR.Region(), - sym_visibility, + Core.Compiler.finish(interp, opt, ir, caller) + + src = Core.Compiler.ir_to_codeinf!(opt) + + if DEBUG_INTERP[] + safe_print("src", src) + end + + # prepare a new code info + code_info = copy(src) + static_params = match.sparams + signature = sig + + # propagate edge metadata, this method is invalidated if the original function we are calling + # is invalidated + code_info.edges = Core.MethodInstance[mi] + code_info.min_world = min_world[] + code_info.max_world = max_world[] + + # Rewrite the arguments to this function, to prepend the two new arguments, the function :call_with_reactant, + # and the REDUB_ARGUMENTS_NAME tuple of input arguments + code_info.slotnames = Any[:call_with_reactant, REDUB_ARGUMENTS_NAME] + code_info.slotflags = UInt8[0x00, 0x00] + n_prepended_slots = 2 + overdub_args_slot = Core.SlotNumber(n_prepended_slots) + + # For the sake of convenience, the rest of this pass will translate `code_info`'s fields + # into these overdubbed equivalents instead of updating `code_info` in-place. Then, at + # the end of the pass, we'll reset `code_info` fields accordingly. + overdubbed_code = Any[] + overdubbed_codelocs = Int32[] + # Rewire the arguments from our tuple input of fn and args, to the corresponding calling convention + # required by the base method. + + # destructure the generated argument slots into the overdubbed method's argument slots. + + offset = 1 + fn_args = Any[] + n_method_args = method.nargs + n_actual_args = length(redub_arguments) + + tys = [] + + iter_args = n_actual_args + if method.isva + iter_args = min(n_actual_args, n_method_args - 1) + end + + for i in 1:iter_args + actual_argument = Expr( + :call, Core.GlobalRef(Core, :getfield), overdub_args_slot, offset ) + push!(overdubbed_code, actual_argument) + push!(overdubbed_codelocs, code_info.codelocs[1]) + offset += 1 + push!(fn_args, Core.SSAValue(length(overdubbed_code))) + push!(tys, redub_arguments[i]) + + if DEBUG_INTERP[] + push!( + overdubbed_code, + Expr( + :call, + safe_print, + "fn arg[" * string(length(fn_args)) * "]", + fn_args[end], + ), + ) + push!(overdubbed_codelocs, code_info.codelocs[1]) + end end - MLIR.API.mlirRegionTakeBody(MLIR.IR.region(func2, 1), MLIR.IR.region(func, 1)) - - MLIR.API.mlirOperationDestroy(func.operation) - func.operation = MLIR.API.MlirOperation(C_NULL) - return ( - false, - func2, - traced_result, - result, - seen_args, - ret, - linear_args, - in_tys, - linear_results, - ) -end -const DEBUG_MODE::Ref{Bool} = Ref(false) + # If `method` is a varargs method, we have to restructure the original method call's + # trailing arguments into a tuple and assign that tuple to the expected argument slot. 
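+    # For example (illustrative): for a method `f(x, ys...)` called as `f(a, b, c)`, the loop
+    # above binds `f` itself and `x = a` from the argument tuple, and the branch below gathers
+    # the remaining elements into `ys = (b, c)`, matching the native varargs convention.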
+ if method.isva + trailing_arguments = Expr(:call, Core.GlobalRef(Core, :tuple)) + for i in n_method_args:n_actual_args + push!( + overdubbed_code, + Expr(:call, Core.GlobalRef(Core, :getfield), overdub_args_slot, offset), + ) + push!(overdubbed_codelocs, code_info.codelocs[1]) + push!(trailing_arguments.args, Core.SSAValue(length(overdubbed_code))) + offset += 1 + end -function with_debug(f) - old = DEBUG_MODE[] - DEBUG_MODE[] = true - try - return f() - finally - DEBUG_MODE[] = old + push!(overdubbed_code, trailing_arguments) + push!(overdubbed_codelocs, code_info.codelocs[1]) + push!(fn_args, Core.SSAValue(length(overdubbed_code))) + push!(tys, Tuple{redub_arguments[n_method_args:n_actual_args]...}) + + if DEBUG_INTERP[] + push!( + overdubbed_code, + Expr( + :call, + safe_print, + "fn arg[" * string(length(fn_args)) * "]", + fn_args[end], + ), + ) + push!(overdubbed_codelocs, code_info.codelocs[1]) + end end -end -function mlir_stacktrace(name, file, line)::MLIR.IR.Location - # calling `stacktrace` can add a lot of time overhead, so let's avoid adding debug info if not used - if DEBUG_MODE[] - return MLIR.IR.Location(name, MLIR.IR.Location(file, line, 0)) + rt = Base.Experimental.compute_ir_rettype(ir) + + # ocva = method.isva + + ocva = false # method.isva + + ocnargs = method.nargs - 1 + # octup = Tuple{mi.specTypes.parameters[2:end]...} + # octup = Tuple{method.sig.parameters[2:end]...} + octup = Tuple{tys[2:end]...} + ocva = false + + # jl_new_opaque_closure forcibly executes in the current world... This means that we won't get the right + # inner code during compilation without special handling (i.e. call_in_world_total). + # Opaque closures also require takign the function argument. We can work around the latter + # if the function is stateless. But regardless, to work around this we sadly create/compile the opaque closure + oc = if false && Base.issingletontype(args[1]) + res = Core._call_in_world_total( + world, make_oc, octup, rt, src, ocnargs, ocva, args[1].instance + )::Core.OpaqueClosure + + else + farg = fn_args[1] + push!(overdubbed_code, Expr(:call, make_oc, octup, rt, src, ocnargs, ocva, farg)) + push!(overdubbed_codelocs, code_info.codelocs[1]) + Core.SSAValue(length(overdubbed_code)) end - # retrieve current stacktrace, remove this function's frame and translate to MLIR Location - st = stacktrace() - deleteat!(st, 1) - return mapfoldl(MLIR.IR.Location, st) do stackframe - name = string(stackframe.func) - file = stackframe.file - line = stackframe.line - return MLIR.IR.Location(name, MLIR.IR.Location(file, line, 0)) + push!(overdubbed_code, Expr(:(call), oc, fn_args[2:end]...)) + + push!(overdubbed_codelocs, code_info.codelocs[1]) + + push!(overdubbed_code, Core.ReturnNode(Core.SSAValue(length(overdubbed_code)))) + push!(overdubbed_codelocs, code_info.codelocs[1]) + + #=== set `code_info`/`reflection` fields accordingly ===# + + if code_info.method_for_inference_limit_heuristics === nothing + code_info.method_for_inference_limit_heuristics = method + end + + code_info.code = overdubbed_code + code_info.codelocs = overdubbed_codelocs + code_info.ssavaluetypes = length(overdubbed_code) + code_info.ssaflags = [0x00 for _ in 1:length(overdubbed_code)] # XXX we need to copy flags that are set for the original code + + if DEBUG_INTERP[] + safe_print("code_info", code_info) end + + return code_info +end + +@eval function call_with_reactant($REDUB_ARGUMENTS_NAME...) 
+ $(Expr(:meta, :generated_only)) + return $(Expr(:meta, :generated, call_with_reactant_generator)) end diff --git a/test/basic.jl b/test/basic.jl index 75859122c..5eff286ed 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -640,7 +640,7 @@ end function f_row_major(x) y = [1 2; 3 4; 5 6] if x isa Reactant.TracedRArray - y = Reactant.promote_to(Reactant.TracedRArray{eltype(x),2}, y) + y = Reactant.TracedUtils.promote_to(Reactant.TracedRArray{eltype(x),2}, y) end return x .+ y end diff --git a/test/complex.jl b/test/complex.jl index 3bf19a051..43e3c4f6b 100644 --- a/test/complex.jl +++ b/test/complex.jl @@ -92,7 +92,7 @@ end y = Reactant.ConcreteRNumber(x) f = Reactant.compile((y,)) do z - z + Reactant.promote_to(Reactant.TracedRNumber{ComplexF64}, 1.0 - 3.0im) + z + Reactant.TracedUtils.promote_to(Reactant.TracedRNumber{ComplexF64}, 1.0 - 3.0im) end @test isapprox(f(y), 2.0 - 1.0im)
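
Usage sketch (illustrative, mirroring the updated tests above): after this change the promotion
helper is addressed as `Reactant.TracedUtils.promote_to` inside a compiled/traced region.

    using Reactant

    x = Reactant.ConcreteRNumber(2.0)
    f = Reactant.compile((x,)) do z
        z + Reactant.TracedUtils.promote_to(Reactant.TracedRNumber{Float64}, 3.0)
    end
    f(x)  # ≈ 5.0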