Interp2 (#365)
* WIP: kernels

* more files

* fix

* wip

* wqtmp

* wip

* inc

* continuing

* wip

* more work

* inf rec

* fix

* overload working

* continuing

* continuing

* push

* fix `call_with_reactant_generator` for Julia 1.11 (#359)

* conversion

* continuing

* Cleanup

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Delete test/cuda.jl

* fixup

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* fix apply

* indep of change

* minor fix in name

* Update utils.jl

* Interp take 2

* continuing adentures

* delcode

* fix

* tmp

* make

* fix

* cleanup

* continuing

* more working

* further simplify

* fx

* more improvements

* minus show

* less prints

* even fewer

* confusion

* tmp

* force clean

* force oc

* clean

* Rewrite

* fixup

* fix

* fix

* fix

* fixup

* fix

* wip

* safe prints

* fix

* fix

* stackoverflow

* cleanup

* dyindex

* rt

* continue

* clean

* fix

* fix

* fix

* fix

* fixup

* fix

* fix

* capture oc

* compile perf

* v1.11 fix

* other way 'round

* formatting

---------

Co-authored-by: William Moses <wsmoses@cyclops.juliacomputing.io>
Co-authored-by: jumerckx <31353884+jumerckx@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: jumerckx <julesmerckx12@gmail.com>
5 people authored Dec 14, 2024
1 parent 73899f5 commit 65e9976
Showing 22 changed files with 1,578 additions and 960 deletions.
4 changes: 2 additions & 2 deletions Project.toml
@@ -41,14 +41,14 @@ Adapt = "4"
ArrayInterface = "7.10"
CEnum = "0.4, 0.5"
Downloads = "1.6"
Enzyme = "0.13.21"
Enzyme = "0.13.22"
EnzymeCore = "0.8.8"
GPUArraysCore = "0.1.6, 0.2"
LinearAlgebra = "1.10"
NNlib = "0.9.26"
OrderedCollections = "1"
Preferences = "1.4"
ReactantCore = "0.1.2"
ReactantCore = "0.1.3"
Reactant_jll = "0.0.26"
Scratch = "1.2"
Statistics = "1.10"
10 changes: 10 additions & 0 deletions deps/ReactantExtra/API.cpp
@@ -376,6 +376,16 @@ extern "C" MlirModule ConvertLLVMToMLIR(LLVMModuleRef lmod, MlirContext cctx) {
return wrap(res);
}

#include "llvm/IRReader/IRReader.h"
extern "C" MlirModule ConvertLLVMStrToMLIR(const char* lmod, MlirContext cctx) {
LLVMContext Context;
SMDiagnostic Err;
auto llvmModule = llvm::parseIR(llvm::MemoryBufferRef(lmod, "conversion"), Err, Context);
mlir::MLIRContext &context = *unwrap(cctx);
auto res = mlir::translateLLVMIRToModule(std::move(llvmModule), &context, /*emitExpensiveWarnings*/false, /*dropDICompositeElements*/false).release();
return wrap(res);
}


/* Note that this */
extern "C" xla::PjRtLoadedExecutable* ClientCompile(PjRtClient * client, MlirModule cmod) {
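
The new `ConvertLLVMStrToMLIR` entry point above parses textual LLVM IR and translates it into an MLIR module. Below is a minimal sketch of how it might be reached from Julia; the wrapper name `convert_llvm_ir` and the library handle `MLIR.API.mlir_c` are assumptions for illustration, not bindings introduced by this diff.

using Reactant: MLIR

# Illustrative wrapper only; `ctx` is the raw C handle (MLIR.API.MlirContext).
function convert_llvm_ir(llvm_ir::AbstractString, ctx::MLIR.API.MlirContext)
    # The C function returns an MlirModule built from the parsed LLVM IR string.
    mod = @ccall MLIR.API.mlir_c.ConvertLLVMStrToMLIR(
        llvm_ir::Cstring, ctx::MLIR.API.MlirContext
    )::MLIR.API.MlirModule
    return MLIR.IR.Module(mod)
end
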
2 changes: 2 additions & 0 deletions deps/ReactantExtra/BUILD
@@ -450,6 +450,8 @@ cc_library(
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:TransformDialect",
"@llvm-project//mlir:Transforms",

"@llvm-project//llvm:IRReader",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:AArch64AsmParser",
"@llvm-project//llvm:AArch64CodeGen",
27 changes: 11 additions & 16 deletions ext/ReactantNNlibExt.jl
@@ -2,16 +2,10 @@ module ReactantNNlibExt

using NNlib
using GPUArraysCore: @allowscalar
using Reactant:
Reactant,
Ops,
TracedRArray,
AnyTracedRArray,
materialize_traced_array,
MLIR,
TracedRNumber,
get_mlir_data,
set_mlir_data!
using Reactant: Reactant, Ops, TracedRArray, AnyTracedRArray, MLIR, TracedRNumber

using Reactant.TracedUtils: materialize_traced_array, get_mlir_data, set_mlir_data!

using ReactantCore: @trace
using LinearAlgebra: LinearAlgebra, triu

@@ -238,9 +232,9 @@ function NNlib.batched_mul!(
if size(x, 3) != size(y, 3)
B = max(size(x, 3), size(y, 3))
if size(x, 3) == 1
x = Reactant.broadcast_to_size(x, (size(x, 1), size(x, 2), B))
x = Reactant.TracedUtils.broadcast_to_size(x, (size(x, 1), size(x, 2), B))
elseif size(y, 3) == 1
y = Reactant.broadcast_to_size(y, (size(y, 1), size(y, 2), B))
y = Reactant.TracedUtils.broadcast_to_size(y, (size(y, 1), size(y, 2), B))
end
end

@@ -250,9 +244,9 @@ function NNlib.batched_mul!(
if size(x, 1) != size(y, 1)
B = max(size(x, 1), size(y, 1))
if size(x, 1) == 1
x = Reactant.broadcast_to_size(x, (B, size(x, 2), size(x, 3)))
x = Reactant.TracedUtils.broadcast_to_size(x, (B, size(x, 2), size(x, 3)))
elseif size(y, 1) == 1
y = Reactant.broadcast_to_size(y, (B, size(y, 2), size(y, 3)))
y = Reactant.TracedUtils.broadcast_to_size(y, (B, size(y, 2), size(y, 3)))
end
end

Expand All @@ -270,7 +264,7 @@ end
function NNlib.pad_constant(
x::AnyTracedRArray{T,N}, pad::NTuple{N,Tuple{Int,Int}}, value
) where {T,N}
value = Reactant.promote_to(TracedRNumber{T}, value)
value = Reactant.TracedUtils.promote_to(TracedRNumber{T}, value)
low = [i[1] for i in pad]
high = [i[2] for i in pad]
interior = [0 for i in pad]
@@ -329,7 +323,8 @@ function NNlib.gather!(dst::TracedRArray, src::AnyTracedRArray, idxs::AbstractAr
start_sizes = ntuple(i -> size(src, i), dims)
results = map(CartesianIndices(idxs)) do k
res = @allowscalar src[colons..., Tuple(idxs[k])...]
res isa TracedRNumber && (res = Reactant.broadcast_to_size(res, (1,)))
res isa TracedRNumber &&
(res = Reactant.TracedUtils.broadcast_to_size(res, (1,)))
return reshape(res, start_sizes..., :)
end
res = reshape(cat(results...; dims=(dims + 1)), size(dst))
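
The `batched_mul!` edits above only re-home `broadcast_to_size` under `Reactant.TracedUtils`; the batch-dimension broadcasting behaviour itself is unchanged. A minimal usage sketch, assuming the usual Reactant compile flow (the names `a`, `b`, and `bmm` are illustrative):

using NNlib
using Reactant

a = Reactant.ConcreteRArray(rand(Float32, 2, 3, 1))  # singleton batch dimension
b = Reactant.ConcreteRArray(rand(Float32, 3, 4, 5))
bmm(x, y) = NNlib.batched_mul(x, y)                  # the singleton batch of `x` is expanded to 5
# bmm_compiled = Reactant.@compile bmm(a, b)         # assumed entry point, not part of this diff
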
3 changes: 2 additions & 1 deletion ext/ReactantStatisticsExt.jl
@@ -1,6 +1,7 @@
module ReactantStatisticsExt

using Reactant: AnyTracedRArray, materialize_traced_array
using Reactant: AnyTracedRArray
using Reactant.TracedUtils: materialize_traced_array
using Statistics: Statistics

function Statistics.mean(A::AnyTracedRArray{T,N}; dims=:) where {T,N}
7 changes: 4 additions & 3 deletions ext/ReactantYaoBlocksExt.jl
@@ -1,12 +1,13 @@
module ReactantYaoBlocksExt

using Reactant
using Reactant.TracedUtils: broadcast_to_size
using YaoBlocks

function YaoBlocks.mat(
::Type{T}, R::RotationGate{D,Reactant.TracedRNumber{S},<:XGate}
) where {D,T,S}
M = Reactant.broadcast_to_size(zero(T), (2, 2))
M = broadcast_to_size(zero(T), (2, 2))
c = cos(R.theta / 2)
s = -im * sin(R.theta / 2)
M[1, 1] = c
@@ -19,7 +20,7 @@ end
function YaoBlocks.mat(
::Type{T}, R::RotationGate{D,Reactant.TracedRNumber{S},<:YGate}
) where {D,T,S}
M = Reactant.broadcast_to_size(zero(T), (2, 2))
M = broadcast_to_size(zero(T), (2, 2))
c = cos(R.theta / 2)
s = sin(R.theta / 2)
M[1, 1] = c
@@ -32,7 +33,7 @@ end
function YaoBlocks.mat(
::Type{T}, R::RotationGate{D,Reactant.TracedRNumber{S},<:ZGate}
) where {D,T,S}
M = Reactant.broadcast_to_size(zero(T), (2, 2))
M = broadcast_to_size(zero(T), (2, 2))
x = exp(im * R.theta / 2)
M[1, 1] = conj(x)
M[2, 2] = x
2 changes: 1 addition & 1 deletion lib/ReactantCore/Project.toml
@@ -1,7 +1,7 @@
name = "ReactantCore"
uuid = "a3311ec8-5e00-46d5-b541-4f83e724a433"
authors = ["William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu>", "Sergio Sánchez Ramírez <sergio.sanchez.ramirez@bsc.es>", "Paul Berg <paul@plutojl.org>", "Avik Pal <avikpal@mit.edu>"]
version = "0.1.2"
version = "0.1.3"

[deps]
ExpressionExplorer = "21656369-7473-754a-2065-74616d696c43"
6 changes: 4 additions & 2 deletions lib/ReactantCore/src/ReactantCore.jl
@@ -153,15 +153,17 @@ function trace_for(mod, expr)

all_syms = Expr(:tuple, counter, external_syms...)
args_init = Expr(
:tuple, :(Reactant.promote_to(Reactant.TracedRNumber{Int}, 0)), external_syms...
:tuple,
:(Reactant.TracedUtils.promote_to(Reactant.TracedRNumber{Int}, 0)),
external_syms...,
)

reactant_code_block = quote
let args = $(args_init)
cond_fn =
$(all_syms) -> begin
local num_iters = div($limit - $start, $step, RoundDown)
local num_iters = Reactant.promote_to(
local num_iters = Reactant.TracedUtils.promote_to(
Reactant.TracedRNumber{Int64}, num_iters
)
$counter < num_iters + 1
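
As a worked check of the iteration count computed in the rewritten `cond_fn` above (numbers are illustrative): for a loop of the form `for i in start:step:limit`, the traced counter begins at 0 and the loop continues while `counter < num_iters + 1`.

# Corresponds to `@trace for i in 1:2:10`, i.e. $start = 1, $step = 2, $limit = 10.
first_i, last_i, step_i = 1, 10, 2
num_iters = div(last_i - first_i, step_i, RoundDown)  # == 4
# The counter takes the values 0, 1, 2, 3, 4 => five iterations,
# matching length(1:2:10) == 5.
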
9 changes: 8 additions & 1 deletion src/Compiler.jl
@@ -292,7 +292,7 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true)
func2, traced_result, result, seen_args, ret, linear_args, in_tys,
linear_results = MLIR.IR.mmodule!(mod) do
MLIR.IR.block!(MLIR.IR.body(mod)) do
return Reactant.make_mlir_fn(f, args, (), "main", true)
return Reactant.TracedUtils.make_mlir_fn(f, args, (), "main", true)
end
end

@@ -779,6 +779,13 @@ function compile(f, args; client=nothing, optimize=true, sync=false)
return register_thunk(fname, body)
end

# Compiling within a compile should return simply the original function
Reactant.@reactant_override function Reactant.Compiler.compile(
f, args; client=nothing, optimize=true, sync=false
)
return f
end

# inspired by RuntimeGeneratedFunction.jl
const __thunk_body_cache = Dict{Symbol,Expr}()

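
The `@reactant_override` added above makes nested compilation a no-op: when a function that is itself being traced calls `Reactant.Compiler.compile`, the inner call simply returns the original callable instead of recursing. A hedged sketch (the names `inner` and `outer` are illustrative only):

inner(x) = x .+ 1

function outer(x)
    g = Reactant.Compiler.compile(inner, (x,))  # while `outer` is being traced, this returns `inner` unchanged
    return g(x)
end

# x = Reactant.ConcreteRArray(rand(Float32, 4))
# outer_compiled = Reactant.@compile outer(x)  # only the outer compile produces a thunk
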
10 changes: 9 additions & 1 deletion src/ConcreteRArray.jl
@@ -99,7 +99,7 @@ end
function Base.convert(
::Type{T}, X::WrappedConcreteRArray{ElType,N}
) where {T<:Array,ElType,N}
fn = compile(materialize_traced_array, (X,))
fn = compile(TracedUtils.materialize_traced_array, (X,))
return convert(Array, fn(X))
end
Base.Array(x::AnyConcreteRArray) = convert(Array, x)
@@ -345,3 +345,11 @@ end

buffer_on_cpu(::Any) = true
buffer_on_cpu(x::ConcreteRArray) = XLA.BufferOnCPU(x.data.buffer)

function Ops.constant(x::ConcreteRArray; kwargs...)
return Ops.constant(Base.convert(Array, x); kwargs...)
end

function Ops.constant(x::ConcreteRNumber{T}; kwargs...) where {T}
return Ops.constant(Base.convert(T, x); kwargs...)
end
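
Two things change above: converting a wrapped `ConcreteRArray` back to a host `Array` now compiles `TracedUtils.materialize_traced_array`, and `Ops.constant` accepts concrete values by first converting them to host data. A hedged usage sketch of the conversion path (variable names are illustrative):

using Reactant

x  = Reactant.ConcreteRArray(rand(Float32, 2, 3))
xt = x'            # an adjoint wrapper, i.e. a WrappedConcreteRArray
host = Array(xt)   # compiles materialize_traced_array and copies the result to the host
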
16 changes: 8 additions & 8 deletions src/ControlFlow.jl
@@ -1,7 +1,7 @@
function ReactantCore.traced_if(
cond::TracedRNumber{Bool}, true_fn::TFn, false_fn::FFn, args
) where {TFn,FFn}
(_, true_branch_compiled, true_branch_results, _, _, _, _, _, true_linear_results) = Reactant.make_mlir_fn(
(_, true_branch_compiled, true_branch_results, _, _, _, _, _, true_linear_results) = Reactant.TracedUtils.make_mlir_fn(
true_fn,
args,
(),
@@ -12,7 +12,7 @@ function ReactantCore.traced_if(
construct_function_without_args=true,
)

(_, false_branch_compiled, false_branch_results, _, _, _, _, _, false_linear_results) = Reactant.make_mlir_fn(
(_, false_branch_compiled, false_branch_results, _, _, _, _, _, false_linear_results) = Reactant.TracedUtils.make_mlir_fn(
false_fn,
args,
(),
@@ -36,16 +36,16 @@ function ReactantCore.traced_if(
returned `$(typeof(tr))`, false branch returned `$(typeof(fr))`.")
elseif tr isa MissingTracedValue
push!(result_types, MLIR.IR.type(fr.mlir_data))
push!(linear_results, new_traced_value(false_linear_results[i]))
push!(linear_results, TracedUtils.new_traced_value(false_linear_results[i]))
push!(true_block_insertions, (i => linear_results[end]))
else
push!(result_types, MLIR.IR.type(tr.mlir_data))
push!(linear_results, new_traced_value(true_linear_results[i]))
push!(linear_results, TracedUtils.new_traced_value(true_linear_results[i]))
push!(false_block_insertions, (i => linear_results[end]))
end
else
push!(result_types, MLIR.IR.type(tr.mlir_data))
push!(linear_results, new_traced_value(tr))
push!(linear_results, TracedUtils.new_traced_value(tr))
end
end

@@ -82,13 +82,13 @@ function ReactantCore.traced_while(
# We promote all incoming args (is there a better way to do this?)
traced_args = [
if v isa Number && !(v isa TracedType)
Reactant.promote_to(TracedRNumber{typeof(v)}, v)
Reactant.TracedUtils.promote_to(TracedRNumber{typeof(v)}, v)
else
v
end for v in args
]

(_, cond_fn_compiled, cond_fn_results, _, _, _, _, in_tys, cond_fn_linear_results) = Reactant.make_mlir_fn(
(_, cond_fn_compiled, cond_fn_results, _, _, _, _, in_tys, cond_fn_linear_results) = Reactant.TracedUtils.make_mlir_fn(
cond_fn,
traced_args,
(),
@@ -99,7 +99,7 @@ end
do_transpose=false,
)

(_, body_fn_compiled, body_fn_results, _, _, _, _, _, body_fn_linear_results) = Reactant.make_mlir_fn(
(_, body_fn_compiled, body_fn_results, _, _, _, _, _, body_fn_linear_results) = Reactant.TracedUtils.make_mlir_fn(
body_fn,
traced_args,
(),
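
Both `traced_if` and `traced_while` above now build their branch, condition, and body functions through `Reactant.TracedUtils.make_mlir_fn`. A minimal usage sketch, assuming the standard `@trace` macro from ReactantCore (function and variable names are illustrative):

using Reactant
using ReactantCore: @trace

function sign_offset(x)
    @trace if sum(x) > 0
        y = x .+ 1
    else
        y = x .- 1
    end
    return y
end

# x = Reactant.ConcreteRArray(rand(Float32, 4))
# sign_offset_compiled = Reactant.@compile sign_offset(x)  # the traced `if` is handled by ReactantCore.traced_if
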