From 9b0207a54934f07445c8a75ea8dd27f5460b8f8a Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sun, 1 Sep 2024 19:17:57 +0200 Subject: [PATCH] Move zero backends to DI itself --- .../src/DifferentiationInterface.jl | 1 + .../src/misc}/zero_backends.jl | 32 ++++++++--------- .../{fromprimitive.jl => from_primitive.jl} | 11 ------ .../test/Internals/zero_backends.jl | 36 +++++++++++++++++++ .../src/DifferentiationInterfaceTest.jl | 1 - DifferentiationInterfaceTest/test/runtests.jl | 2 +- .../test/{zero.jl => zero_backends.jl} | 26 +++++--------- 7 files changed, 63 insertions(+), 46 deletions(-) rename {DifferentiationInterfaceTest/src/utils => DifferentiationInterface/src/misc}/zero_backends.jl (70%) rename DifferentiationInterface/test/Internals/{fromprimitive.jl => from_primitive.jl} (78%) create mode 100644 DifferentiationInterface/test/Internals/zero_backends.jl rename DifferentiationInterfaceTest/test/{zero.jl => zero_backends.jl} (81%) diff --git a/DifferentiationInterface/src/DifferentiationInterface.jl b/DifferentiationInterface/src/DifferentiationInterface.jl index f21c39c7e..7b155320a 100644 --- a/DifferentiationInterface/src/DifferentiationInterface.jl +++ b/DifferentiationInterface/src/DifferentiationInterface.jl @@ -77,6 +77,7 @@ include("sparse/hessian.jl") include("misc/differentiate_with.jl") include("misc/sparsity_detector.jl") include("misc/from_primitive.jl") +include("misc/zero_backends.jl") function __init__() @require_extensions diff --git a/DifferentiationInterfaceTest/src/utils/zero_backends.jl b/DifferentiationInterface/src/misc/zero_backends.jl similarity index 70% rename from DifferentiationInterfaceTest/src/utils/zero_backends.jl rename to DifferentiationInterface/src/misc/zero_backends.jl index 0a4d9a056..d30289203 100644 --- a/DifferentiationInterfaceTest/src/utils/zero_backends.jl +++ b/DifferentiationInterface/src/misc/zero_backends.jl @@ -11,13 +11,13 @@ Used in testing and benchmarking. struct AutoZeroForward <: AbstractADType end ADTypes.mode(::AutoZeroForward) = ForwardMode() -DI.check_available(::AutoZeroForward) = true -DI.twoarg_support(::AutoZeroForward) = DI.TwoArgSupported() +check_available(::AutoZeroForward) = true +twoarg_support(::AutoZeroForward) = TwoArgSupported() -DI.prepare_pushforward(f, ::AutoZeroForward, x, tx::Tangents) = NoPushforwardExtras() -DI.prepare_pushforward(f!, y, ::AutoZeroForward, x, tx::Tangents) = NoPushforwardExtras() +prepare_pushforward(f, ::AutoZeroForward, x, tx::Tangents) = NoPushforwardExtras() +prepare_pushforward(f!, y, ::AutoZeroForward, x, tx::Tangents) = NoPushforwardExtras() -function DI.value_and_pushforward( +function value_and_pushforward( f, ::AutoZeroForward, x, tx::Tangents{B}, ::NoPushforwardExtras ) where {B} y = f(x) @@ -25,7 +25,7 @@ function DI.value_and_pushforward( return y, Tangents(dys) end -function DI.value_and_pushforward( +function value_and_pushforward( f!, y, ::AutoZeroForward, x, tx::Tangents{B}, ::NoPushforwardExtras ) where {B} f!(y, x) @@ -33,7 +33,7 @@ function DI.value_and_pushforward( return y, Tangents(dys) end -function DI.value_and_pushforward!( +function value_and_pushforward!( f, ty::Tangents, ::AutoZeroForward, x, tx::Tangents, ::NoPushforwardExtras ) y = f(x) @@ -43,7 +43,7 @@ function DI.value_and_pushforward!( return y, ty end -function DI.value_and_pushforward!( +function value_and_pushforward!( f!, y, ty::Tangents, ::AutoZeroForward, x, tx::Tangents, ::NoPushforwardExtras ) f!(y, x) @@ -64,13 +64,13 @@ Used in testing and benchmarking. struct AutoZeroReverse <: AbstractADType end ADTypes.mode(::AutoZeroReverse) = ReverseMode() -DI.check_available(::AutoZeroReverse) = true -DI.twoarg_support(::AutoZeroReverse) = DI.TwoArgSupported() +check_available(::AutoZeroReverse) = true +twoarg_support(::AutoZeroReverse) = TwoArgSupported() -DI.prepare_pullback(f, ::AutoZeroReverse, x, ty::Tangents) = NoPullbackExtras() -DI.prepare_pullback(f!, y, ::AutoZeroReverse, x, ty::Tangents) = NoPullbackExtras() +prepare_pullback(f, ::AutoZeroReverse, x, ty::Tangents) = NoPullbackExtras() +prepare_pullback(f!, y, ::AutoZeroReverse, x, ty::Tangents) = NoPullbackExtras() -function DI.value_and_pullback( +function value_and_pullback( f, ::AutoZeroReverse, x, ty::Tangents{B}, ::NoPullbackExtras ) where {B} y = f(x) @@ -78,7 +78,7 @@ function DI.value_and_pullback( return y, Tangents(dxs) end -function DI.value_and_pullback( +function value_and_pullback( f!, y, ::AutoZeroReverse, x, ty::Tangents{B}, ::NoPullbackExtras ) where {B} f!(y, x) @@ -86,7 +86,7 @@ function DI.value_and_pullback( return y, Tangents(dxs) end -function DI.value_and_pullback!( +function value_and_pullback!( f, tx::Tangents, ::AutoZeroReverse, x, ty::Tangents, ::NoPullbackExtras ) y = f(x) @@ -96,7 +96,7 @@ function DI.value_and_pullback!( return y, tx end -function DI.value_and_pullback!( +function value_and_pullback!( f!, y, tx::Tangents, ::AutoZeroReverse, x, ty::Tangents, ::NoPullbackExtras ) f!(y, x) diff --git a/DifferentiationInterface/test/Internals/fromprimitive.jl b/DifferentiationInterface/test/Internals/from_primitive.jl similarity index 78% rename from DifferentiationInterface/test/Internals/fromprimitive.jl rename to DifferentiationInterface/test/Internals/from_primitive.jl index 4d80270d2..7ef9c2884 100644 --- a/DifferentiationInterface/test/Internals/fromprimitive.jl +++ b/DifferentiationInterface/test/Internals/from_primitive.jl @@ -17,15 +17,4 @@ for backend in vcat(fromprimitive_backends) @test DifferentiationInterface.pick_batchsize(backend, 100) == 5 end -## Dense backends - test_differentiation(fromprimitive_backends, default_scenarios(); logging=LOGGING); - -test_differentiation( - fromprimitive_backends[1], - default_scenarios(); - correctness=false, - type_stability=true, - second_order=false, - logging=LOGGING, -); diff --git a/DifferentiationInterface/test/Internals/zero_backends.jl b/DifferentiationInterface/test/Internals/zero_backends.jl new file mode 100644 index 000000000..6d55eb48f --- /dev/null +++ b/DifferentiationInterface/test/Internals/zero_backends.jl @@ -0,0 +1,36 @@ +using DifferentiationInterface +using DifferentiationInterface: AutoZeroForward, AutoZeroReverse +using DifferentiationInterfaceTest +using Test + +LOGGING = get(ENV, "CI", "false") == "false" + +zero_backends = [AutoZeroForward(), AutoZeroReverse()] + +for backend in zero_backends + @test check_available(backend) + @test check_twoarg(backend) +end + +## Type stability + +test_differentiation( + zero_backends, + default_scenarios(); + correctness=false, + type_stability=true, + excluded=[:second_derivative], + logging=LOGGING, +) + +test_differentiation( + [ + SecondOrder(AutoZeroForward(), AutoZeroReverse()), + SecondOrder(AutoZeroReverse(), AutoZeroForward()), + ], + default_scenarios(; linalg=false); + correctness=false, + type_stability=true, + first_order=false, + logging=LOGGING, +) diff --git a/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl b/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl index 7d967bedf..028114893 100644 --- a/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl +++ b/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl @@ -60,7 +60,6 @@ include("scenarios/sparse.jl") include("scenarios/allocfree.jl") include("scenarios/extensions.jl") -include("utils/zero_backends.jl") include("utils/misc.jl") include("utils/filter.jl") diff --git a/DifferentiationInterfaceTest/test/runtests.jl b/DifferentiationInterfaceTest/test/runtests.jl index d3b1aec82..3fb338813 100644 --- a/DifferentiationInterfaceTest/test/runtests.jl +++ b/DifferentiationInterfaceTest/test/runtests.jl @@ -18,7 +18,7 @@ GROUP = get(ENV, "JULIA_DIT_TEST_GROUP", "All") if GROUP == "Zero" || GROUP == "All" @testset verbose = false "Zero" begin - include("zero.jl") + include("zero_backends.jl") end end diff --git a/DifferentiationInterfaceTest/test/zero.jl b/DifferentiationInterfaceTest/test/zero_backends.jl similarity index 81% rename from DifferentiationInterfaceTest/test/zero.jl rename to DifferentiationInterfaceTest/test/zero_backends.jl index b80674de4..193d22e0e 100644 --- a/DifferentiationInterfaceTest/test/zero.jl +++ b/DifferentiationInterfaceTest/test/zero_backends.jl @@ -1,12 +1,8 @@ using DifferentiationInterface +using DifferentiationInterface: AutoZeroForward, AutoZeroReverse using DifferentiationInterfaceTest using DifferentiationInterfaceTest: - AutoZeroForward, - AutoZeroReverse, - scenario_to_zero, - test_allocfree, - allocfree_scenarios, - remove_batched + scenario_to_zero, test_allocfree, allocfree_scenarios, remove_batched using ComponentArrays: ComponentArrays using JLArrays: JLArrays using StaticArrays: StaticArrays @@ -15,18 +11,14 @@ using Test LOGGING = get(ENV, "CI", "false") == "false" -@test check_available(AutoZeroForward()) -@test check_available(AutoZeroReverse()) -@test check_twoarg(AutoZeroForward()) -@test check_twoarg(AutoZeroReverse()) - ## Correctness + type stability test_differentiation( [AutoZeroForward(), AutoZeroReverse()], - scenario_to_zero.(default_scenarios()); - correctness=true, - type_stability=false, # TODO: switch back + default_scenarios(); + correctness=false, + type_stability=true, + excluded=[:second_derivative], logging=LOGGING, ) @@ -35,9 +27,9 @@ test_differentiation( SecondOrder(AutoZeroForward(), AutoZeroReverse()), SecondOrder(AutoZeroReverse(), AutoZeroForward()), ], - scenario_to_zero.(default_scenarios(; linalg=false)); - correctness=true, - type_stability=false, # TODO: switch back + default_scenarios(; linalg=false); + correctness=false, + type_stability=true, first_order=false, logging=LOGGING, )