feat: add dynamic_update_slice_const_prop pass + tril + triu (#334)

* feat: add dynamic_update_slice_const_prop pass

* refactor: move linear algebra overloads to a different file

* feat: add triu and tril impl

* refactor: minimize batch_op

* feat: add Ops.compare

* refactor: use ops in base dispatches

* refactor: move linear algebra tests

* fix: tril defn and inplace ops

* test: add inplace tests
avik-pal authored Dec 7, 2024
1 parent 1bb0000 commit 3a0710d
Showing 9 changed files with 183 additions and 119 deletions.
15 changes: 0 additions & 15 deletions ext/ReactantNNlibExt.jl
@@ -298,21 +298,6 @@ function NNlib.pad_constant(
return TracedRArray{T,N}((), res, size(MLIR.IR.type(res)))
end

function NNlib.make_causal_mask(x::AnyTracedRArray; dims::Int=2)
len = size(x, dims)
# directly generating booleans was causing an incorrect constant attribute generation,
# but the optimized IR removes the type cast so we are probably ok
mask = MLIR.IR.DenseElementsAttribute(collect(triu(fill(1, (len, len)))))
return Reactant.promote_to(
TracedRArray{Bool,2},
TracedRArray{Int,2}(
(),
MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=mask), 1),
(len, len),
),
)
end
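
With triu now overloaded for traced arrays (added in src/linear_algebra.jl below), the hand-rolled constant construction above is presumably redundant. As a plain-Julia illustration of the mask this function built (an assumption about equivalence, not part of this commit):

using LinearAlgebra
len = 4                            # stands in for size(x, dims)
mask = triu(ones(Bool, len, len))  # true on and above the diagonal: the causal mask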

# XXX: reevaluate this manual optimization once
# https://github.com/EnzymeAD/Enzyme-JAX/issues/164 is handled
function NNlib.gather!(
1 change: 1 addition & 0 deletions src/Compiler.jl
@@ -245,6 +245,7 @@ const opt_passes::String = join(
"pad_dot_general<1>(1)",
"if_inline<1>",
"if_to_select<1>",
"dynamic_update_slice_const_prop",
],
';',
) *
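
For reference: stablehlo.dynamic_update_slice writes an update tensor into an operand at given start indices, and the new pass constant-folds the op when operand, update, and indices are all compile-time constants. A plain-Julia analogue of the fold (illustrative only; not the pass's actual mechanics):

operand = zeros(Int, 4, 4)        # constant operand
update  = ones(Int, 2, 2)         # constant update
i, j    = 2, 2                    # constant start indices
operand[i:i+1, j:j+1] .= update   # the whole op folds away to this constant result
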
32 changes: 32 additions & 0 deletions src/Ops.jl
@@ -1014,4 +1014,36 @@ function select(
return TracedRNumber{T}((), res)
end

# comparison
function compare(
lhs::Union{TracedRArray{T},TracedRNumber{T}},
rhs::Union{TracedRArray{T},TracedRNumber{T}};
comparison_direction::String,
compare_type=nothing,
location=mlir_stacktrace("compare", @__FILE__, @__LINE__),
) where {T}
@assert comparison_direction in ("EQ", "NE", "GE", "GT", "LE", "LT")
@assert size(lhs) == size(rhs)
if lhs isa TracedRNumber
@assert rhs isa TracedRNumber
else
@assert rhs isa TracedRArray
end

res = MLIR.IR.result(
MLIR.Dialects.stablehlo.compare(
lhs.mlir_data,
rhs.mlir_data;
comparison_direction=MLIR.API.stablehloComparisonDirectionAttrGet(
MLIR.IR.context(), comparison_direction
),
compare_type,
location,
),
1,
)
lhs isa TracedRNumber && return TracedRNumber{Bool}((), res)
return TracedRArray{Bool,ndims(lhs)}((), res, size(lhs))
end
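
A minimal usage sketch for the new op, following the to_rarray / @jit pattern used in the test suite (the wrapper f is illustrative):

using Reactant
f(a, b) = Reactant.Ops.compare(a, b; comparison_direction="LT")
x = Reactant.to_rarray(rand(Float32, 4))
y = Reactant.to_rarray(rand(Float32, 4))
@jit f(x, y)   # elementwise Bool result, true where x[i] < y[i]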

end
3 changes: 3 additions & 0 deletions src/Reactant.jl
@@ -98,8 +98,11 @@ include("utils.jl")
include("ConcreteRArray.jl")
include("TracedRNumber.jl")
include("TracedRArray.jl")

include("Ops.jl")

include("linear_algebra.jl")

const TracedType = Union{TracedRArray,TracedRNumber,MissingTracedValue}

include("ControlFlow.jl")
89 changes: 0 additions & 89 deletions src/TracedRArray.jl
@@ -415,95 +415,6 @@ for (jlop, hloop, hlocomp, merge) in
end
end

function LinearAlgebra.mul!(
@nospecialize(C::TracedRArray{T1,1}),
@nospecialize(A::AnyTracedRArray{T2,2}),
@nospecialize(B::AnyTracedRArray{T3,1}),
α::Number=true,
β::Number=false,
) where {T1,T2,T3}
# TODO: the reshape operations are not getting optimized; we should call dot_general directly
rC = reshape(C, :, 1)
LinearAlgebra.mul!(rC, A, reshape(B, :, 1), α, β)
C.mlir_data = get_mlir_data(vec(rC))
return C
end

function LinearAlgebra.mul!(
@nospecialize(C::TracedRArray{T1,2}),
@nospecialize(A::AnyTracedRArray{T2,2}),
@nospecialize(B::AnyTracedRArray{T3,1}),
α::Number=true,
β::Number=false,
) where {T1,T2,T3}
LinearAlgebra.mul!(C, A, reshape(B, :, 1), α, β)
return C
end

function LinearAlgebra.mul!(
@nospecialize(C::TracedRArray{T1,2}),
@nospecialize(A::AnyTracedRArray{T2,2}),
@nospecialize(B::AnyTracedRArray{T3,2}),
α::Number=true,
β::Number=false,
) where {T1,T2,T3}
if size(C) != (size(A, 1), size(B, 2))
throw(
DimensionMismatch(
"C has size $(size(C)), A has size $(size(A)), B has size $(size(B))"
),
)
end
if size(A, 2) != size(B, 1)
throw(DimensionMismatch("A has size $(size(A)), B has size $(size(B))"))
end
resty = MLIR.IR.TensorType(size(C), MLIR.IR.Type(T1))
dot_dimension_numbers = MLIR.API.stablehloDotDimensionNumbersGet(
MLIR.IR.context(), 0, [], 0, [], 1, [1], 1, [0]
)
prec = MLIR.IR.Attribute(
MLIR.API.stablehloPrecisionAttrGet(MLIR.IR.context(), "DEFAULT")
)
precar = MLIR.IR.Attribute([prec, prec])
res = MLIR.IR.result(
MLIR.Dialects.stablehlo.dot_general(
get_mlir_data(A),
get_mlir_data(B);
result_0=resty,
dot_dimension_numbers=dot_dimension_numbers,
precision_config=precar,
),
1,
)
if iszero(β)
if isone(α)
C.mlir_data = res
else
C.mlir_data = MLIR.IR.result(
MLIR.Dialects.stablehlo.multiply(
res, broadcast_to_size(T1(α), size(C)).mlir_data
),
1,
)
end
else
α_res = MLIR.IR.result(
MLIR.Dialects.stablehlo.multiply(
res, broadcast_to_size(T1(α), size(C)).mlir_data
),
1,
)
β_C = MLIR.IR.result(
MLIR.Dialects.stablehlo.multiply(
C.mlir_data, broadcast_to_size(T1(β), size(C)).mlir_data
),
1,
)
C.mlir_data = MLIR.IR.result(MLIR.Dialects.stablehlo.add(α_res, β_C), 1)
end
return C
end

function Enzyme.Compiler.active_reg_inner(
::Type{TracedRArray{T,N}},
seen::ST,
14 changes: 1 addition & 13 deletions src/TracedRNumber.jl
@@ -151,19 +151,7 @@ for (jlop, hloop, hlocomp) in (
function $(jlop)(
@nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs::TracedRNumber{T})
) where {T}
return TracedRNumber{Bool}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.$(hloop)(
lhs.mlir_data,
rhs.mlir_data;
comparison_direction=MLIR.API.stablehloComparisonDirectionAttrGet(
MLIR.IR.context(), $hlocomp
),
),
1,
),
)
return Ops.compare(lhs, rhs; comparison_direction=$(hlocomp))
end

function $(jlop)(@nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs)) where {T}
108 changes: 108 additions & 0 deletions src/linear_algebra.jl
@@ -0,0 +1,108 @@
function LinearAlgebra.mul!(
@nospecialize(C::TracedRArray{T1,1}),
@nospecialize(A::AnyTracedRArray{T2,2}),
@nospecialize(B::AnyTracedRArray{T3,1}),
α::Number=true,
β::Number=false,
) where {T1,T2,T3}
# TODO: the reshape operations are not getting optimized; we should call dot_general directly
rC = reshape(C, :, 1)
LinearAlgebra.mul!(rC, A, reshape(B, :, 1), α, β)
C.mlir_data = get_mlir_data(vec(rC))
return C
end

function LinearAlgebra.mul!(
@nospecialize(C::TracedRArray{T1,2}),
@nospecialize(A::AnyTracedRArray{T2,2}),
@nospecialize(B::AnyTracedRArray{T3,1}),
α::Number=true,
β::Number=false,
) where {T1,T2,T3}
LinearAlgebra.mul!(C, A, reshape(B, :, 1), α, β)
return C
end

function LinearAlgebra.mul!(
@nospecialize(C::TracedRArray{T1,2}),
@nospecialize(A::AnyTracedRArray{T2,2}),
@nospecialize(B::AnyTracedRArray{T3,2}),
α::Number=true,
β::Number=false,
) where {T1,T2,T3}
if size(C) != (size(A, 1), size(B, 2))
throw(
DimensionMismatch(
"C has size $(size(C)), A has size $(size(A)), B has size $(size(B))"
),
)
end
if size(A, 2) != size(B, 1)
throw(DimensionMismatch("A has size $(size(A)), B has size $(size(B))"))
end
resty = MLIR.IR.TensorType(size(C), MLIR.IR.Type(T1))
dot_dimension_numbers = MLIR.API.stablehloDotDimensionNumbersGet(
MLIR.IR.context(), 0, [], 0, [], 1, [1], 1, [0]
)
prec = MLIR.IR.Attribute(
MLIR.API.stablehloPrecisionAttrGet(MLIR.IR.context(), "DEFAULT")
)
precar = MLIR.IR.Attribute([prec, prec])
res = MLIR.IR.result(
MLIR.Dialects.stablehlo.dot_general(
get_mlir_data(A),
get_mlir_data(B);
result_0=resty,
dot_dimension_numbers=dot_dimension_numbers,
precision_config=precar,
),
1,
)
if iszero(β)
if isone(α)
C.mlir_data = res
else
C.mlir_data = MLIR.IR.result(
MLIR.Dialects.stablehlo.multiply(
res, broadcast_to_size(T1(α), size(C)).mlir_data
),
1,
)
end
else
α_res = MLIR.IR.result(
MLIR.Dialects.stablehlo.multiply(
res, broadcast_to_size(T1(α), size(C)).mlir_data
),
1,
)
β_C = MLIR.IR.result(
MLIR.Dialects.stablehlo.multiply(
C.mlir_data, broadcast_to_size(T1(β), size(C)).mlir_data
),
1,
)
C.mlir_data = MLIR.IR.result(MLIR.Dialects.stablehlo.add(α_res, β_C), 1)
end
return C
end
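
Usage matches Base's mul!; a sketch mirroring the tests below:

using LinearAlgebra, Reactant
A = Reactant.to_rarray(rand(4, 4))
B = Reactant.to_rarray(rand(4, 2))
C = Reactant.to_rarray(zeros(4, 2))
@jit mul!(C, A, B)   # C ≈ A * B; with α, β the branches above give C = α*(A*B) + β*C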

function LinearAlgebra.triu!(@nospecialize(X::TracedRArray{T,2}), k::Integer) where {T}
iota_1 = Ops.iota(Int64, [size(X)...]; iota_dimension=1)
iota_2 = Ops.subtract(
Ops.iota(Int64, [size(X)...]; iota_dimension=2), broadcast_to_size(k, size(X))
)
idxs = Ops.compare(iota_1, iota_2; comparison_direction="LE")
X.mlir_data = Ops.select(idxs, X, zero(X)).mlir_data
return X
end

function LinearAlgebra.tril!(@nospecialize(X::TracedRArray{T,2}), k::Integer) where {T}
iota_1 = Ops.iota(Int64, [size(X)...]; iota_dimension=1)
iota_2 = Ops.subtract(
Ops.iota(Int64, [size(X)...]; iota_dimension=2), broadcast_to_size(k, size(X))
)
idxs = Ops.compare(iota_1, iota_2; comparison_direction="GE")
X.mlir_data = Ops.select(idxs, X, zero(X)).mlir_data
return X
end
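
Both kernels use the standard iota-mask trick: build row and column index grids, compare them against the diagonal offset k, and select between X and zero. A plain-Julia analogue of what triu! expresses (illustrative, not the Reactant lowering itself):

m, n, k = 4, 6, 1
rows = [i for i in 1:m, j in 1:n]   # Ops.iota along dimension 1
cols = [j for i in 1:m, j in 1:n]   # Ops.iota along dimension 2
keep = rows .<= cols .- k           # the "LE" compare; "GE" yields tril!
X = rand(m, n)
X .= ifelse.(keep, X, zero(eltype(X)))
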
38 changes: 37 additions & 1 deletion test/linear_algebra.jl → test/integration/linear_algebra.jl
@@ -45,7 +45,7 @@ function mul_with_view3(A, x)
return C
end

@testset begin
@testset "Matrix Multiplication" begin
A = rand(4, 4)
x = rand(4, 2)
b = rand(4)
@@ -77,3 +77,39 @@ end
@jit(mul!(C_ra, A_ra, x_ra))
@test C_ra ≈ A * x
end

@testset "triu & tril" begin
A = rand(4, 6)
A_ra = Reactant.to_rarray(A)

@test @jit(triu(A_ra)) ≈ triu(A)
@test @jit(tril(A_ra)) ≈ tril(A)
@test @jit(triu(A_ra, 2)) ≈ triu(A, 2)
@test @jit(tril(A_ra, 2)) ≈ tril(A, 2)
@test @jit(triu(A_ra, -1)) ≈ triu(A, -1)
@test @jit(tril(A_ra, -1)) ≈ tril(A, -1)

A_ra = Reactant.to_rarray(A)
@jit(triu!(A_ra))
@test A_ra ≈ triu(A)

A_ra = Reactant.to_rarray(A)
@jit(tril!(A_ra))
@test A_ra ≈ tril(A)

A_ra = Reactant.to_rarray(A)
@jit(triu!(A_ra, 2))
@test A_ra ≈ triu(A, 2)

A_ra = Reactant.to_rarray(A)
@jit(tril!(A_ra, 2))
@test A_ra ≈ tril(A, 2)

A_ra = Reactant.to_rarray(A)
@jit(triu!(A_ra, -1))
@test A_ra ≈ triu(A, -1)

A_ra = Reactant.to_rarray(A)
@jit(tril!(A_ra, -1))
@test A_ra ≈ tril(A, -1)
end
2 changes: 1 addition & 1 deletion test/runtests.jl
@@ -56,10 +56,10 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))
@safetestset "Shortcuts to MLIR ops" include("ops.jl")
@safetestset "Wrapped Arrays" include("wrapped_arrays.jl")
@safetestset "Control Flow" include("control_flow.jl")
@safetestset "Linear Algebra" include("linear_algebra.jl")
end

if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration"
@safetestset "Linear Algebra" include("integration/linear_algebra.jl")
@safetestset "AbstractFFTs" include("integration/fft.jl")
end
