From 3a0710d283645349b53898a77ea5e6e335be5f0c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 7 Dec 2024 12:26:24 +0530 Subject: [PATCH] feat: add `dynamic_update_slice_const_prop` pass + `tril` + `triu` (#334) * feat: add dynamic_update_slice_const_prop pass * refactor: move linear algebra overloads to a different file * feat: add triu and tril impl * refactor: minimize batch_op * feat: add Ops.compare * refactor: use ops in base dispatches * refactor: move linear algebra tests * fix: tril defn and inplace ops * test: add inplace tests --- ext/ReactantNNlibExt.jl | 15 ---- src/Compiler.jl | 1 + src/Ops.jl | 32 +++++++ src/Reactant.jl | 3 + src/TracedRArray.jl | 89 ------------------- src/TracedRNumber.jl | 14 +-- src/linear_algebra.jl | 108 +++++++++++++++++++++++ test/{ => integration}/linear_algebra.jl | 38 +++++++- test/runtests.jl | 2 +- 9 files changed, 183 insertions(+), 119 deletions(-) create mode 100644 src/linear_algebra.jl rename test/{ => integration}/linear_algebra.jl (68%) diff --git a/ext/ReactantNNlibExt.jl b/ext/ReactantNNlibExt.jl index b78716d29..b90fa60fb 100644 --- a/ext/ReactantNNlibExt.jl +++ b/ext/ReactantNNlibExt.jl @@ -298,21 +298,6 @@ function NNlib.pad_constant( return TracedRArray{T,N}((), res, size(MLIR.IR.type(res))) end -function NNlib.make_causal_mask(x::AnyTracedRArray; dims::Int=2) - len = size(x, dims) - # directly generating booleans were causing an incorrect constant attribute generation - # but the optimized IR removes the type case so we are probably ok - mask = MLIR.IR.DenseElementsAttribute(collect(triu(fill(1, (len, len))))) - return Reactant.promote_to( - TracedRArray{Bool,2}, - TracedRArray{Int,2}( - (), - MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=mask), 1), - (len, len), - ), - ) -end - # XXX: reevaluate this manual optimization once # https://github.com/EnzymeAD/Enzyme-JAX/issues/164 is handled function NNlib.gather!( diff --git a/src/Compiler.jl b/src/Compiler.jl index 0bd3eaa4f..586f33b05 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -245,6 +245,7 @@ const opt_passes::String = join( "pad_dot_general<1>(1)", "if_inline<1>", "if_to_select<1>", + "dynamic_update_slice_const_prop", ], ';', ) * diff --git a/src/Ops.jl b/src/Ops.jl index e9f3d3252..2148cb5eb 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -1014,4 +1014,36 @@ function select( return TracedRNumber{T}((), res) end +# comparison +function compare( + lhs::Union{TracedRArray{T},TracedRNumber{T}}, + rhs::Union{TracedRArray{T},TracedRNumber{T}}; + comparison_direction::String, + compare_type=nothing, + location=mlir_stacktrace("compare", @__FILE__, @__LINE__), +) where {T} + @assert comparison_direction in ("EQ", "NE", "GE", "GT", "LE", "LT") + @assert size(lhs) == size(rhs) + if lhs isa TracedRNumber + @assert rhs isa TracedRNumber + else + @assert rhs isa TracedRArray + end + + res = MLIR.IR.result( + MLIR.Dialects.stablehlo.compare( + lhs.mlir_data, + rhs.mlir_data; + comparison_direction=MLIR.API.stablehloComparisonDirectionAttrGet( + MLIR.IR.context(), comparison_direction + ), + compare_type, + location, + ), + 1, + ) + lhs isa TracedRNumber && return TracedRNumber{Bool}((), res) + return TracedRArray{Bool,ndims(lhs)}((), res, size(lhs)) +end + end diff --git a/src/Reactant.jl b/src/Reactant.jl index 036edfa56..0b73d3d96 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -98,8 +98,11 @@ include("utils.jl") include("ConcreteRArray.jl") include("TracedRNumber.jl") include("TracedRArray.jl") + include("Ops.jl") +include("linear_algebra.jl") + const TracedType = 
Union{TracedRArray,TracedRNumber,MissingTracedValue} include("ControlFlow.jl") diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 91d8df004..97a29e56f 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -415,95 +415,6 @@ for (jlop, hloop, hlocomp, merge) in end end -function LinearAlgebra.mul!( - @nospecialize(C::TracedRArray{T1,1}), - @nospecialize(A::AnyTracedRArray{T2,2}), - @nospecialize(B::AnyTracedRArray{T3,1}), - α::Number=true, - β::Number=false, -) where {T1,T2,T3} - # TODO: The reshape operations are not getting optimized, we should directly call dot_general - rC = reshape(C, :, 1) - LinearAlgebra.mul!(rC, A, reshape(B, :, 1), α, β) - C.mlir_data = get_mlir_data(vec(rC)) - return C -end - -function LinearAlgebra.mul!( - @nospecialize(C::TracedRArray{T1,2}), - @nospecialize(A::AnyTracedRArray{T2,2}), - @nospecialize(B::AnyTracedRArray{T3,1}), - α::Number=true, - β::Number=false, -) where {T1,T2,T3} - LinearAlgebra.mul!(C, A, reshape(B, :, 1), α, β) - return C -end - -function LinearAlgebra.mul!( - @nospecialize(C::TracedRArray{T1,2}), - @nospecialize(A::AnyTracedRArray{T2,2}), - @nospecialize(B::AnyTracedRArray{T3,2}), - α::Number=true, - β::Number=false, -) where {T1,T2,T3} - if size(C) != (size(A, 1), size(B, 2)) - throw( - DimensionMismatch( - "C has size $(size(C)), A has size $(size(A)), B has size $(size(B))" - ), - ) - end - if size(A, 2) != size(B, 1) - throw(DimensionMismatch("A has size $(size(A)), B has size $(size(B))")) - end - resty = MLIR.IR.TensorType(size(C), MLIR.IR.Type(T1)) - dot_dimension_numbers = MLIR.API.stablehloDotDimensionNumbersGet( - MLIR.IR.context(), 0, [], 0, [], 1, [1], 1, [0] - ) - prec = MLIR.IR.Attribute( - MLIR.API.stablehloPrecisionAttrGet(MLIR.IR.context(), "DEFAULT") - ) - precar = MLIR.IR.Attribute([prec, prec]) - res = MLIR.IR.result( - MLIR.Dialects.stablehlo.dot_general( - get_mlir_data(A), - get_mlir_data(B); - result_0=resty, - dot_dimension_numbers=dot_dimension_numbers, - precision_config=precar, - ), - 1, - ) - if iszero(β) - if isone(α) - C.mlir_data = res - else - C.mlir_data = MLIR.IR.result( - MLIR.Dialects.stablehlo.multiply( - res, broadcast_to_size(T1(α), size(C)).mlir_data - ), - 1, - ) - end - else - α_res = MLIR.IR.result( - MLIR.Dialects.stablehlo.multiply( - res, broadcast_to_size(T1(α), size(C)).mlir_data - ), - 1, - ) - β_C = MLIR.IR.result( - MLIR.Dialects.stablehlo.multiply( - C.mlir_data, broadcast_to_size(T1(β), size(C)).mlir_data - ), - 1, - ) - C.mlir_data = MLIR.IR.result(MLIR.Dialects.stablehlo.add(α_res, β_C), 1) - end - return C -end - function Enzyme.Compiler.active_reg_inner( ::Type{TracedRArray{T,N}}, seen::ST, diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index 9780faf9d..12c5123b7 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -151,19 +151,7 @@ for (jlop, hloop, hlocomp) in ( function $(jlop)( @nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs::TracedRNumber{T}) ) where {T} - return TracedRNumber{Bool}( - (), - MLIR.IR.result( - MLIR.Dialects.stablehlo.$(hloop)( - lhs.mlir_data, - rhs.mlir_data; - comparison_direction=MLIR.API.stablehloComparisonDirectionAttrGet( - MLIR.IR.context(), $hlocomp - ), - ), - 1, - ), - ) + return Ops.compare(lhs, rhs; comparison_direction=$(hlocomp)) end function $(jlop)(@nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs)) where {T} diff --git a/src/linear_algebra.jl b/src/linear_algebra.jl new file mode 100644 index 000000000..c7e72651d --- /dev/null +++ b/src/linear_algebra.jl @@ -0,0 +1,108 @@ +function 
LinearAlgebra.mul!( + @nospecialize(C::TracedRArray{T1,1}), + @nospecialize(A::AnyTracedRArray{T2,2}), + @nospecialize(B::AnyTracedRArray{T3,1}), + α::Number=true, + β::Number=false, +) where {T1,T2,T3} + # TODO: The reshape operations are not getting optimized, we should directly call dot_general + rC = reshape(C, :, 1) + LinearAlgebra.mul!(rC, A, reshape(B, :, 1), α, β) + C.mlir_data = get_mlir_data(vec(rC)) + return C +end + +function LinearAlgebra.mul!( + @nospecialize(C::TracedRArray{T1,2}), + @nospecialize(A::AnyTracedRArray{T2,2}), + @nospecialize(B::AnyTracedRArray{T3,1}), + α::Number=true, + β::Number=false, +) where {T1,T2,T3} + LinearAlgebra.mul!(C, A, reshape(B, :, 1), α, β) + return C +end + +function LinearAlgebra.mul!( + @nospecialize(C::TracedRArray{T1,2}), + @nospecialize(A::AnyTracedRArray{T2,2}), + @nospecialize(B::AnyTracedRArray{T3,2}), + α::Number=true, + β::Number=false, +) where {T1,T2,T3} + if size(C) != (size(A, 1), size(B, 2)) + throw( + DimensionMismatch( + "C has size $(size(C)), A has size $(size(A)), B has size $(size(B))" + ), + ) + end + if size(A, 2) != size(B, 1) + throw(DimensionMismatch("A has size $(size(A)), B has size $(size(B))")) + end + resty = MLIR.IR.TensorType(size(C), MLIR.IR.Type(T1)) + dot_dimension_numbers = MLIR.API.stablehloDotDimensionNumbersGet( + MLIR.IR.context(), 0, [], 0, [], 1, [1], 1, [0] + ) + prec = MLIR.IR.Attribute( + MLIR.API.stablehloPrecisionAttrGet(MLIR.IR.context(), "DEFAULT") + ) + precar = MLIR.IR.Attribute([prec, prec]) + res = MLIR.IR.result( + MLIR.Dialects.stablehlo.dot_general( + get_mlir_data(A), + get_mlir_data(B); + result_0=resty, + dot_dimension_numbers=dot_dimension_numbers, + precision_config=precar, + ), + 1, + ) + if iszero(β) + if isone(α) + C.mlir_data = res + else + C.mlir_data = MLIR.IR.result( + MLIR.Dialects.stablehlo.multiply( + res, broadcast_to_size(T1(α), size(C)).mlir_data + ), + 1, + ) + end + else + α_res = MLIR.IR.result( + MLIR.Dialects.stablehlo.multiply( + res, broadcast_to_size(T1(α), size(C)).mlir_data + ), + 1, + ) + β_C = MLIR.IR.result( + MLIR.Dialects.stablehlo.multiply( + C.mlir_data, broadcast_to_size(T1(β), size(C)).mlir_data + ), + 1, + ) + C.mlir_data = MLIR.IR.result(MLIR.Dialects.stablehlo.add(α_res, β_C), 1) + end + return C +end + +function LinearAlgebra.triu!(@nospecialize(X::TracedRArray{T,2}), k::Integer) where {T} + iota_1 = Ops.iota(Int64, [size(X)...]; iota_dimension=1) + iota_2 = Ops.subtract( + Ops.iota(Int64, [size(X)...]; iota_dimension=2), broadcast_to_size(k, size(X)) + ) + idxs = Ops.compare(iota_1, iota_2; comparison_direction="LE") + X.mlir_data = Ops.select(idxs, X, zero(X)).mlir_data + return X +end + +function LinearAlgebra.tril!(@nospecialize(X::TracedRArray{T,2}), k::Integer) where {T} + iota_1 = Ops.iota(Int64, [size(X)...]; iota_dimension=1) + iota_2 = Ops.subtract( + Ops.iota(Int64, [size(X)...]; iota_dimension=2), broadcast_to_size(k, size(X)) + ) + idxs = Ops.compare(iota_1, iota_2; comparison_direction="GE") + X.mlir_data = Ops.select(idxs, X, zero(X)).mlir_data + return X +end diff --git a/test/linear_algebra.jl b/test/integration/linear_algebra.jl similarity index 68% rename from test/linear_algebra.jl rename to test/integration/linear_algebra.jl index 1b35c6483..22fe07c1f 100644 --- a/test/linear_algebra.jl +++ b/test/integration/linear_algebra.jl @@ -45,7 +45,7 @@ function mul_with_view3(A, x) return C end -@testset begin +@testset "Matrix Multiplication" begin A = rand(4, 4) x = rand(4, 2) b = rand(4) @@ -77,3 +77,39 @@ end @jit(mul!(C_ra, 
A_ra, x_ra)) @test C_ra ≈ A * x end + +@testset "triu & tril" begin + A = rand(4, 6) + A_ra = Reactant.to_rarray(A) + + @test @jit(triu(A_ra)) ≈ triu(A) + @test @jit(tril(A_ra)) ≈ tril(A) + @test @jit(triu(A_ra, 2)) ≈ triu(A, 2) + @test @jit(tril(A_ra, 2)) ≈ tril(A, 2) + @test @jit(triu(A_ra, -1)) ≈ triu(A, -1) + @test @jit(tril(A_ra, -1)) ≈ tril(A, -1) + + A_ra = Reactant.to_rarray(A) + @jit(triu!(A_ra)) + @test A_ra ≈ triu(A) + + A_ra = Reactant.to_rarray(A) + @jit(tril!(A_ra)) + @test A_ra ≈ tril(A) + + A_ra = Reactant.to_rarray(A) + @jit(triu!(A_ra, 2)) + @test A_ra ≈ triu(A, 2) + + A_ra = Reactant.to_rarray(A) + @jit(tril!(A_ra, 2)) + @test A_ra ≈ tril(A, 2) + + A_ra = Reactant.to_rarray(A) + @jit(triu!(A_ra, -1)) + @test A_ra ≈ triu(A, -1) + + A_ra = Reactant.to_rarray(A) + @jit(tril!(A_ra, -1)) + @test A_ra ≈ tril(A, -1) +end diff --git a/test/runtests.jl b/test/runtests.jl index 34212e696..fddc963ce 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -56,10 +56,10 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) @safetestset "Shortcuts to MLIR ops" include("ops.jl") @safetestset "Wrapped Arrays" include("wrapped_arrays.jl") @safetestset "Control Flow" include("control_flow.jl") - @safetestset "Linear Algebra" include("linear_algebra.jl") end if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration" + @safetestset "Linear Algebra" include("integration/linear_algebra.jl") @safetestset "AbstractFFTs" include("integration/fft.jl") end
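
The triu!/tril! definitions in src/linear_algebra.jl build their masks from two iota tensors and a select. A minimal plain-Julia model of that mask logic, for reference only (triu_mask_model is a hypothetical name; the traced code uses Ops.iota, Ops.compare, and Ops.select rather than broadcasting, and stablehlo iota is 0-based, which cancels out because it appears on both sides of the comparison):

using LinearAlgebra

# triu(X, k) keeps X[i, j] where i <= j - k (the "LE" direction above);
# tril(X, k) keeps X[i, j] where i >= j - k (the "GE" direction above).
function triu_mask_model(X::AbstractMatrix, k::Integer=0)
    rows = [i for i in 1:size(X, 1), j in 1:size(X, 2)]  # iota along dim 1
    cols = [j for i in 1:size(X, 1), j in 1:size(X, 2)]  # iota along dim 2
    return ifelse.(rows .<= cols .- k, X, zero(eltype(X)))
end

X = rand(4, 6)
triu_mask_model(X, 2) == triu(X, 2)  # true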
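
The TracedRNumber.jl hunk reroutes every Base comparison through the new Ops.compare, so elementwise comparisons should each lower to a single stablehlo.compare with the matching direction attribute. An illustrative sketch (f is a hypothetical function; to_rarray and @jit are used exactly as in the tests above):

using Reactant

x = Reactant.to_rarray([1.0, 4.0, 2.0])
y = Reactant.to_rarray([3.0, 3.0, 3.0])

# Broadcasting `<` hits the TracedRNumber dispatch, which now calls
# Ops.compare(lhs, rhs; comparison_direction="LT").
f(a, b) = a .< b

@jit(f(x, y))  # expected to match [true, false, true]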
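
For the dot_general-based 5-argument mul! moved into src/linear_algebra.jl, this plain-Julia reference spells out the computation the traced version emits: one stablehlo.dot_general for A*B, then no scaling, one multiply, or two multiplies plus an add, depending on α and β (mul_model! is a hypothetical name, not the traced implementation):

function mul_model!(C, A, B, α=true, β=false)
    res = A * B                         # stablehlo.dot_general
    if iszero(β)
        C .= isone(α) ? res : α .* res  # skip the multiply when α == 1
    else
        C .= α .* res .+ β .* C         # two multiplies and an add
    end
    return C
end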
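
Finally, on the deleted NNlib.make_causal_mask override: with triu now supported on traced arrays, NNlib's generic fallback is presumably intended to take over. If that fallback derives the mask purely from size(x) via triu, as the removed workaround's comment implies, a call like the following should trace without the override (an assumption, not verified here):

using NNlib, Reactant

x = Reactant.to_rarray(rand(Float32, 8, 16))
mask = @jit(NNlib.make_causal_mask(x))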