Skip to content

Commit

Permalink
Move zero backends to DI itself
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Sep 1, 2024
1 parent 7c60378 commit 9b0207a
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 46 deletions.
1 change: 1 addition & 0 deletions DifferentiationInterface/src/DifferentiationInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ include("sparse/hessian.jl")
include("misc/differentiate_with.jl")
include("misc/sparsity_detector.jl")
include("misc/from_primitive.jl")
include("misc/zero_backends.jl")

function __init__()
@require_extensions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,29 +11,29 @@ Used in testing and benchmarking.
struct AutoZeroForward <: AbstractADType end

ADTypes.mode(::AutoZeroForward) = ForwardMode()
DI.check_available(::AutoZeroForward) = true
DI.twoarg_support(::AutoZeroForward) = DI.TwoArgSupported()
check_available(::AutoZeroForward) = true
twoarg_support(::AutoZeroForward) = TwoArgSupported()

DI.prepare_pushforward(f, ::AutoZeroForward, x, tx::Tangents) = NoPushforwardExtras()
DI.prepare_pushforward(f!, y, ::AutoZeroForward, x, tx::Tangents) = NoPushforwardExtras()
prepare_pushforward(f, ::AutoZeroForward, x, tx::Tangents) = NoPushforwardExtras()
prepare_pushforward(f!, y, ::AutoZeroForward, x, tx::Tangents) = NoPushforwardExtras()

function DI.value_and_pushforward(
function value_and_pushforward(
f, ::AutoZeroForward, x, tx::Tangents{B}, ::NoPushforwardExtras
) where {B}
y = f(x)
dys = ntuple(Returns(zero(y)), Val(B))
return y, Tangents(dys)
end

function DI.value_and_pushforward(
function value_and_pushforward(
f!, y, ::AutoZeroForward, x, tx::Tangents{B}, ::NoPushforwardExtras
) where {B}
f!(y, x)
dys = ntuple(Returns(zero(y)), Val(B))
return y, Tangents(dys)
end

function DI.value_and_pushforward!(
function value_and_pushforward!(
f, ty::Tangents, ::AutoZeroForward, x, tx::Tangents, ::NoPushforwardExtras
)
y = f(x)
Expand All @@ -43,7 +43,7 @@ function DI.value_and_pushforward!(
return y, ty
end

function DI.value_and_pushforward!(
function value_and_pushforward!(
f!, y, ty::Tangents, ::AutoZeroForward, x, tx::Tangents, ::NoPushforwardExtras
)
f!(y, x)
Expand All @@ -64,29 +64,29 @@ Used in testing and benchmarking.
struct AutoZeroReverse <: AbstractADType end

ADTypes.mode(::AutoZeroReverse) = ReverseMode()
DI.check_available(::AutoZeroReverse) = true
DI.twoarg_support(::AutoZeroReverse) = DI.TwoArgSupported()
check_available(::AutoZeroReverse) = true
twoarg_support(::AutoZeroReverse) = TwoArgSupported()

DI.prepare_pullback(f, ::AutoZeroReverse, x, ty::Tangents) = NoPullbackExtras()
DI.prepare_pullback(f!, y, ::AutoZeroReverse, x, ty::Tangents) = NoPullbackExtras()
prepare_pullback(f, ::AutoZeroReverse, x, ty::Tangents) = NoPullbackExtras()
prepare_pullback(f!, y, ::AutoZeroReverse, x, ty::Tangents) = NoPullbackExtras()

function DI.value_and_pullback(
function value_and_pullback(
f, ::AutoZeroReverse, x, ty::Tangents{B}, ::NoPullbackExtras
) where {B}
y = f(x)
dxs = ntuple(Returns(zero(x)), Val(B))
return y, Tangents(dxs)
end

function DI.value_and_pullback(
function value_and_pullback(
f!, y, ::AutoZeroReverse, x, ty::Tangents{B}, ::NoPullbackExtras
) where {B}
f!(y, x)
dxs = ntuple(Returns(zero(x)), Val(B))
return y, Tangents(dxs)
end

function DI.value_and_pullback!(
function value_and_pullback!(
f, tx::Tangents, ::AutoZeroReverse, x, ty::Tangents, ::NoPullbackExtras
)
y = f(x)
Expand All @@ -96,7 +96,7 @@ function DI.value_and_pullback!(
return y, tx
end

function DI.value_and_pullback!(
function value_and_pullback!(
f!, y, tx::Tangents, ::AutoZeroReverse, x, ty::Tangents, ::NoPullbackExtras
)
f!(y, x)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,4 @@ for backend in vcat(fromprimitive_backends)
@test DifferentiationInterface.pick_batchsize(backend, 100) == 5
end

## Dense backends

test_differentiation(fromprimitive_backends, default_scenarios(); logging=LOGGING);

test_differentiation(
fromprimitive_backends[1],
default_scenarios();
correctness=false,
type_stability=true,
second_order=false,
logging=LOGGING,
);
36 changes: 36 additions & 0 deletions DifferentiationInterface/test/Internals/zero_backends.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
using DifferentiationInterface
using DifferentiationInterface: AutoZeroForward, AutoZeroReverse
using DifferentiationInterfaceTest
using Test

LOGGING = get(ENV, "CI", "false") == "false"

zero_backends = [AutoZeroForward(), AutoZeroReverse()]

for backend in zero_backends
@test check_available(backend)
@test check_twoarg(backend)
end

## Type stability

test_differentiation(
zero_backends,
default_scenarios();
correctness=false,
type_stability=true,
excluded=[:second_derivative],
logging=LOGGING,
)

test_differentiation(
[
SecondOrder(AutoZeroForward(), AutoZeroReverse()),
SecondOrder(AutoZeroReverse(), AutoZeroForward()),
],
default_scenarios(; linalg=false);
correctness=false,
type_stability=true,
first_order=false,
logging=LOGGING,
)
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ include("scenarios/sparse.jl")
include("scenarios/allocfree.jl")
include("scenarios/extensions.jl")

include("utils/zero_backends.jl")
include("utils/misc.jl")
include("utils/filter.jl")

Expand Down
2 changes: 1 addition & 1 deletion DifferentiationInterfaceTest/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ GROUP = get(ENV, "JULIA_DIT_TEST_GROUP", "All")

if GROUP == "Zero" || GROUP == "All"
@testset verbose = false "Zero" begin
include("zero.jl")
include("zero_backends.jl")
end
end

Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
using DifferentiationInterface
using DifferentiationInterface: AutoZeroForward, AutoZeroReverse
using DifferentiationInterfaceTest
using DifferentiationInterfaceTest:
AutoZeroForward,
AutoZeroReverse,
scenario_to_zero,
test_allocfree,
allocfree_scenarios,
remove_batched
scenario_to_zero, test_allocfree, allocfree_scenarios, remove_batched
using ComponentArrays: ComponentArrays
using JLArrays: JLArrays
using StaticArrays: StaticArrays
Expand All @@ -15,18 +11,14 @@ using Test

LOGGING = get(ENV, "CI", "false") == "false"

@test check_available(AutoZeroForward())
@test check_available(AutoZeroReverse())
@test check_twoarg(AutoZeroForward())
@test check_twoarg(AutoZeroReverse())

## Correctness + type stability

test_differentiation(
[AutoZeroForward(), AutoZeroReverse()],
scenario_to_zero.(default_scenarios());
correctness=true,
type_stability=false, # TODO: switch back
default_scenarios();
correctness=false,
type_stability=true,
excluded=[:second_derivative],
logging=LOGGING,
)

Expand All @@ -35,9 +27,9 @@ test_differentiation(
SecondOrder(AutoZeroForward(), AutoZeroReverse()),
SecondOrder(AutoZeroReverse(), AutoZeroForward()),
],
scenario_to_zero.(default_scenarios(; linalg=false));
correctness=true,
type_stability=false, # TODO: switch back
default_scenarios(; linalg=false);
correctness=false,
type_stability=true,
first_order=false,
logging=LOGGING,
)
Expand Down

0 comments on commit 9b0207a

Please sign in to comment.