Skip to content

Commit

Permalink
Generalize broadcast
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jun 29, 2024
1 parent 49eaacb commit 5134fe0
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 30 deletions.
17 changes: 10 additions & 7 deletions deps/build_local.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,7 @@ source_dir = joinpath(@__DIR__, "ReactantExtra")
# 2. Ensure that an appropriate LLVM_full_jll is installed
Pkg.activate(; temp=true)

# Build!
@info "Building" source_dir scratch_dir
run(`mkdir -p $(scratch_dir)`)
run(Cmd(`$(Base.julia_cmd().exec[1]) --project=. -e "using Pkg; Pkg.instantiate()"`, dir=source_dir))
# --action_env TF_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,compute_90"
run(Cmd(`bazel build -c dbg --action_env=JULIA=$(Base.julia_cmd().exec[1])
cuda = """
--repo_env TF_NEED_CUDA=1
--repo_env TF_DOWNLOAD_CLANG=1
--repo_env TF_CUDA_PATHS="/usr/local/cuda"
Expand All @@ -39,8 +34,16 @@ run(Cmd(`bazel build -c dbg --action_env=JULIA=$(Base.julia_cmd().exec[1])
--@xla//xla/python:enable_gpu=true
--@xla//xla/python:jax_cuda_pip_rpaths=true
--define=xla_python_enable_gpu=true
"""

# Build!
@info "Building" source_dir scratch_dir
run(`mkdir -p $(scratch_dir)`)
# Instantiate the Julia project in `source_dir` using the julia executable that is
# running this script.
run(Cmd(`$(Base.julia_cmd().exec[1]) --project=. -e "using Pkg; Pkg.instantiate()"`, dir=source_dir))
# --action_env TF_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,compute_90"
# `env=` supplies the environment for the bazel process, so both PATH (with the parent
# of `source_dir` prepended) and HOME are passed explicitly.
run(Cmd(`bazel build -c dbg --action_env=JULIA=$(Base.julia_cmd().exec[1])
--check_visibility=false --verbose_failures :libReactantExtra.so :Builtin.inc.jl :Arith.inc.jl :Affine.inc.jl :Func.inc.jl :Enzyme.inc.jl :StableHLO.inc.jl :CHLO.inc.jl :VHLO.inc.jl`, dir=source_dir,
env=Dict("PATH"=>joinpath(source_dir, "..")*":"*ENV["PATH"], "HOME"=>ENV["HOME"])))

# Recreate the `.dylib` name as a symlink to the freshly built `.so` inside bazel-bin.
run(Cmd(`rm -f libReactantExtra.dylib`, dir=joinpath(source_dir, "bazel-bin")))
run(Cmd(`ln -s libReactantExtra.so libReactantExtra.dylib`, dir=joinpath(source_dir, "bazel-bin")))
Expand Down
44 changes: 25 additions & 19 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ function append_path(path, i)
return (path..., i)
end

@inline function make_tracer(seen::IdDict, prev::RT, path, mode, data) where {RT}
@inline function make_tracer(seen::IdDict, prev::RT, path, mode, data; toscalar=false, tobatch=nothing) where {RT}
if haskey(seen, prev)
return seen[prev]
end
Expand All @@ -443,7 +443,7 @@ end
subs = []
for i in 1:nf
xi = Base.getfield(prev, i)
xi2 = make_tracer(seen, xi, append_path(path, i), mode, data)
xi2 = make_tracer(seen, xi, append_path(path, i), mode, data; toscalar, tobatch)
if xi !== xi2
changed = true
end
Expand All @@ -465,7 +465,7 @@ end
for i in 1:nf
if isdefined(prev, i)
xi = Base.getfield(prev, i)
xi2 = make_tracer(seen, xi, append_path(path, i), mode, data)
xi2 = make_tracer(seen, xi, append_path(path, i), mode, data; toscalar, tobatch)
if xi !== xi2
changed = true
end
Expand All @@ -488,7 +488,7 @@ end
for i in 1:nf
if isdefined(prev, i)
xi = Base.getfield(prev, i)
xi2 = make_tracer(seen, xi, append_path(path, i), mode, data)
xi2 = make_tracer(seen, xi, append_path(path, i), mode, data; toscalar, tobatch)
if xi !== xi2
changed = true
end
Expand All @@ -508,7 +508,7 @@ end
end

@inline function make_tracer(
seen::IdDict, prev::ConcreteRArray{ElType,Shape,N}, path, mode, data
seen::IdDict, prev::ConcreteRArray{ElType,Shape,N}, path, mode, data; toscalar=false, tobatch=nothing
) where {ElType,Shape,N}
if mode == ArrayToConcrete
return prev
Expand All @@ -526,7 +526,7 @@ end
end

@inline function make_tracer(
seen::IdDict, prev::TracedRArray{ElType,Shape,N}, path, mode, data
seen::IdDict, prev::TracedRArray{ElType,Shape,N}, path, mode, data; toscalar=false, tobatch=nothing
) where {ElType,Shape,N}
if mode == ConcreteToTraced
throw("Cannot trace existing trace type")
Expand All @@ -542,7 +542,13 @@ end
if haskey(seen, prev)
return seen[prev]
end
res = TracedRArray{ElType,Shape,N}((path,), prev.mlir_data)
res = if toscalar
TracedRArray{ElType,(),0}((path,), nothing)
elseif tobatch !== nothing
TracedRArray{ElType,tobatch,length(tobatch)}((path,), prev.mlir_data)
else
TracedRArray{ElType,Shape,N}((path,), prev.mlir_data)
end
seen[prev] = res
return res
end
Expand All @@ -560,19 +566,19 @@ end
end

# Floating-point scalars are never wrapped: `prev` is returned unchanged. The
# `toscalar`/`tobatch` keywords are accepted only so that callers can forward them
# uniformly to every `make_tracer` method; they are ignored here. The `toscalar`
# default is `false`, matching the other methods.
@inline function make_tracer(
    seen::IdDict, prev::RT, path, mode, data; toscalar=false, tobatch=nothing
) where {RT<:AbstractFloat}
    return prev
end

# Complex numbers are traced component-wise: the real and imaginary parts are each
# passed through `make_tracer` (with `:re` / `:im` appended to `path`) and the results
# are recombined with `Complex`. `toscalar` and `tobatch` are forwarded unchanged.
@inline function make_tracer(
    seen::IdDict, prev::Complex{RT}, path, mode, data; toscalar=false, tobatch=nothing
) where {RT}
    return Complex(
        make_tracer(seen, prev.re, append_path(path, :re), mode, data; toscalar, tobatch),
        make_tracer(seen, prev.im, append_path(path, :im), mode, data; toscalar, tobatch),
    )
end

@inline function make_tracer(seen::IdDict, prev::RT, path, mode, data) where {RT<:Array}
@inline function make_tracer(seen::IdDict, prev::RT, path, mode, data; toscalar=false, tobatch=nothing) where {RT<:Array}
if haskey(seen, prev)
return seen[prev]
end
Expand All @@ -586,7 +592,7 @@ end
for I in eachindex(prev)
if isassigned(prev, I)
pv = prev[I]
nv = make_tracer(seen, pv, append_path(path, I), mode, data)
nv = make_tracer(seen, pv, append_path(path, I), mode, data; toscalar, tobatch)
if pv !== nv
same = false
end
Expand All @@ -600,33 +606,33 @@ end
return newa
end

# Tuples are rebuilt element by element: each entry is passed through `make_tracer`
# with its 1-based position appended to `path`, and the results are splatted into a
# new tuple. `toscalar` and `tobatch` are forwarded unchanged to every element.
@inline function make_tracer(
    seen::IdDict, prev::RT, path, mode, data; toscalar=false, tobatch=nothing
) where {RT<:Tuple}
    return (
        (
            make_tracer(seen, v, append_path(path, i), mode, data; toscalar, tobatch) for
            (i, v) in enumerate(prev)
        )...,
    )
end

# Named tuples are rebuilt field by field: each field value is passed through
# `make_tracer` with the field name appended to `path`. The result keeps the same
# field names `A`; its value type parameter is `traced_type(RT, (), Val(mode))`.
# `toscalar` and `tobatch` are forwarded unchanged to every field.
@inline function make_tracer(
    seen::IdDict, prev::NamedTuple{A,RT}, path, mode, data; toscalar=false, tobatch=nothing
) where {A,RT}
    return NamedTuple{A,traced_type(RT, (), Val(mode))}((
        (
            make_tracer(
                seen,
                Base.getfield(prev, name),
                append_path(path, name),
                mode,
                data;
                toscalar,
                tobatch,
            ) for name in A
        )...,
    ))
end

@inline function make_tracer(seen::IdDict, prev::Core.Box, path, mode, data)
@inline function make_tracer(seen::IdDict, prev::Core.Box, path, mode, data; toscalar=false, tobatch=nothing)
if haskey(seen, prev)
return seen[prev]
end
prev2 = prev.contents
tr = make_tracer(seen, prev2, append_path(path, :contents), mode, data)
tr = make_tracer(seen, prev2, append_path(path, :contents), mode, data; toscalar, tobatch)
if tr == prev2
seen[prev] = prev
return prev
Expand Down
89 changes: 89 additions & 0 deletions src/overloads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,95 @@ for (jlop, hloop) in (
end
end


"""
    elem_apply(f, args...)

Trace `f` over `args` into an MLIR function via `make_mlir_fn`, emit a single
`enzyme.batch` operation that calls that function, and return the traced result
re-wrapped with the batched output shape.

The output shape is taken from `size` of the traced arguments recorded in
`seen_args`; at least one such argument must exist.

# Arguments
- `f`: the callable to apply.
- `args...`: the arguments forwarded to `f`.

# Returns
The value produced by `make_tracer(...; tobatch=OutShape)` on the traced `result`.
"""
function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}
    mod = MLIR.IR.module()

    # NOTE(review): `make_mlir_fn` accepts a `toscalar` keyword and this function is
    # named "..._broadcast_scalar", but the keyword is not passed here -- confirm
    # whether `toscalar=true` is intended.
    fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = make_mlir_fn(
        mod, f, args, (), string(f) * "_broadcast_scalar", false
    )

    # Map each traced value back to the original argument it was created from, and
    # remember the shape of the original arguments as the batch/output shape.
    invmap = IdDict()
    OutShape = nothing
    for (k, v) in seen_args
        invmap[v] = k
        OutShape = size(k)
    end
    @assert OutShape !== nothing
    in_tys2 = [mlir_type(invmap[arg]) for arg in linear_args]

    out_tys2 = [
        MLIR.IR.TensorType(OutShape, MLIR.IR.Type(eltype(arg))) for arg in linear_results
    ]

    fname = get_attribute_by_name(func2, "sym_name")
    fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname))

    batch_inputs = MLIR.IR.Value[]

    # When `fnwrap` is set, linear argument index 1 refers to `f` itself and the
    # remaining indices are shifted by one relative to `args`.
    for a in linear_args
        idx, path = get_argidx(a)
        if idx == 1 && fnwrap
            # NOTE(review): `reverse` is not defined in this function, so this passes
            # whatever `reverse` resolves to at module scope -- confirm against the
            # signature of `push_acts!`.
            push_acts!(batch_inputs, f, path[3:end], reverse)
        else
            if fnwrap
                idx -= 1
            end
            push_acts!(batch_inputs, args[idx], path[3:end], reverse)
        end
    end

    # NOTE(review): `DenseArrayAttribute` is used unqualified while every other MLIR
    # name here is spelled `MLIR.IR.*` -- confirm it is in scope in this file.
    res = MLIR.Dialects.enzyme.batch(
        batch_inputs;
        outputs=out_tys2,
        fn=fname,
        batch_sizes=DenseArrayAttribute([Int64(i) for i in OutShape]),
    )

    # Results of the batch op are consumed in order; each one is written back either
    # into `result` or into the argument (or `f`) it corresponds to.
    residx = 1

    for a in linear_results
        if has_residx(a)
            path = get_residx(a)
            set!(result, path[2:end], MLIR.IR.result(res, residx))
            residx += 1
        else
            idx, path = get_argidx(a)
            if idx == 1 && fnwrap
                set!(f, path[3:end], MLIR.IR.result(res, residx))
                residx += 1
            else
                if fnwrap
                    idx -= 1
                end
                set!(args[idx], path[3:end], MLIR.IR.result(res, residx))
                residx += 1
            end
        end
    end

    # Fresh identity cache for re-wrapping the results with the batched shape.
    seen_results = IdDict()
    traced2_result = make_tracer(
        seen_results, result, (), TracedSetPath, nothing #=data=#; tobatch=OutShape
    )

    return traced2_result
end

for (jlop, hloop, RT) in (
(:(Base.min), :minimum, :ElType),
(:(Base.max), :maximum, :ElType),
Expand Down
11 changes: 7 additions & 4 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ function apply(f, args...; kwargs...)
return f(args...; kwargs...)
end

function make_mlir_fn(mod, f, args, kwargs, name="main", concretein=true)
function make_mlir_fn(mod, f, args, kwargs, name="main", concretein=true; toscalar=false)
if sizeof(typeof(f)) != 0
return (
true, make_mlir_fn(mod, apply, (f, args...), kwargs, name, concretein)[2:end]...
Expand All @@ -29,8 +29,7 @@ function make_mlir_fn(mod, f, args, kwargs, name="main", concretein=true)
args[i],
("args", i),
concretein ? ConcreteToTraced : TracedSetPath,
nothing,
) #=data=#
nothing; toscalar)
end

linear_args = TracedRArray[]
Expand All @@ -41,7 +40,11 @@ function make_mlir_fn(mod, f, args, kwargs, name="main", concretein=true)
push!(linear_args, v)
end

in_tys = [transpose_ty(mlir_type(arg)) for arg in linear_args]
in_tys = if traced_scalar
[MLIR.IR.TensorType((), MLIR.IR.Type(eltype(arg))) for arg in linear_args]
else
[transpose_ty(mlir_type(arg)) for arg in linear_args]
end

sym_visibility = nothing
if !concretein
Expand Down

0 comments on commit 5134fe0

Please sign in to comment.