Skip to content

Commit

Permalink
fix: check nothing output for Zygote (#667)
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle authored Jan 4, 2025
1 parent 9df2763 commit f3f7b29
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,24 @@ using ForwardDiff: ForwardDiff
using Zygote:
ZygoteRuleConfig, gradient, hessian, jacobian, pullback, withgradient, withjacobian

struct ZygoteNothingError <: Exception
f
x
contexts
end

function Base.showerror(io::IO, e::ZygoteNothingError)
(; f, x, contexts) = e
sig = (typeof(x), map(typeof DI.unwrap, contexts)...)
return print(
io,
"Zygote failed to differentiate function `$f` with argument types `$sig` (the pullback returned `nothing`).",
)
end

check_nothing(::Nothing, f, x, contexts) = throw(ZygoteNothingError(f, x, contexts))
check_nothing(::Any, f, x, contexts) = nothing

DI.check_available(::AutoZygote) = true
DI.inplace_support(::AutoZygote) = DI.InPlaceNotSupported()

Expand Down Expand Up @@ -46,6 +64,7 @@ function DI.value_and_pullback(
tx = map(ty) do dy
first(pb(dy))
end
check_nothing(first(tx), f, x, contexts)
return y, tx
end

Expand All @@ -61,6 +80,7 @@ function DI.value_and_pullback(
tx = map(ty) do dy
first(pb(dy))
end
check_nothing(first(tx), f, x, contexts)
return copy(y), tx
end

Expand All @@ -76,6 +96,7 @@ function DI.pullback(
tx = map(ty) do dy
first(pb(dy))
end
check_nothing(first(tx), f, x, contexts)
return tx
end

Expand All @@ -95,6 +116,7 @@ function DI.value_and_gradient(
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
) where {C}
(; val, grad) = withgradient(f, x, map(DI.unwrap, contexts)...)
check_nothing(first(grad), f, x, contexts)
return val, first(grad)
end

Expand All @@ -105,7 +127,9 @@ function DI.gradient(
x,
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
) where {C}
return first(gradient(f, x, map(DI.unwrap, contexts)...))
grad = gradient(f, x, map(DI.unwrap, contexts)...)
check_nothing(first(grad), f, x, contexts)
return first(grad)
end

function DI.value_and_gradient!(
Expand Down Expand Up @@ -146,8 +170,11 @@ function DI.value_and_jacobian(
x,
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
) where {C}
return f(x, map(DI.unwrap, contexts)...),
first(jacobian(f, x, map(DI.unwrap, contexts)...)) # https://github.com/FluxML/Zygote.jl/issues/1506
y = f(x, map(DI.unwrap, contexts)...)
# https://github.com/FluxML/Zygote.jl/issues/1506
jac = jacobian(f, x, map(DI.unwrap, contexts)...)
check_nothing(first(jac), f, x, contexts)
return y, first(jac)
end

function DI.jacobian(
Expand All @@ -157,7 +184,9 @@ function DI.jacobian(
x,
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
) where {C}
return first(jacobian(f, x, map(DI.unwrap, contexts)...))
jac = jacobian(f, x, map(DI.unwrap, contexts)...)
check_nothing(first(jac), f, x, contexts)
return first(jac)
end

function DI.value_and_jacobian!(
Expand Down Expand Up @@ -266,7 +295,9 @@ function DI.hessian(
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
) where {C}
fc = DI.with_contexts(f, contexts...)
return hessian(fc, x)
hess = hessian(fc, x)
check_nothing(hess, f, x, contexts)
return hess
end

function DI.hessian!(
Expand Down
53 changes: 32 additions & 21 deletions DifferentiationInterface/test/Back/Zygote/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,27 +24,38 @@ end

## Dense

test_differentiation(
backends,
default_scenarios(; include_constantified=true);
excluded=[:second_derivative],
logging=LOGGING,
);

test_differentiation(second_order_backends; logging=LOGGING);

test_differentiation(
backends[1],
vcat(component_scenarios(), gpu_scenarios());
excluded=SECOND_ORDER,
logging=LOGGING,
)
@testset "Dense" begin
test_differentiation(
backends,
default_scenarios(; include_constantified=true);
excluded=[:second_derivative],
logging=LOGGING,
)

test_differentiation(second_order_backends; logging=LOGGING)

test_differentiation(
backends[1],
vcat(component_scenarios(), gpu_scenarios());
excluded=SECOND_ORDER,
logging=LOGGING,
)
end

## Sparse

test_differentiation(
MyAutoSparse.(vcat(backends, second_order_backends)),
sparse_scenarios(; band_sizes=0:-1);
sparsity=true,
logging=LOGGING,
)
@testset "Sparse" begin
test_differentiation(
MyAutoSparse.(vcat(backends, second_order_backends)),
sparse_scenarios(; band_sizes=0:-1);
sparsity=true,
logging=LOGGING,
)
end

## Errors

@testset "Errors" begin
safe_log(x) = x > zero(x) ? log(x) : convert(typeof(x), NaN)
@test_throws "Zygote failed to differentiate" derivative(safe_log, AutoZygote(), 0.0)
end

0 comments on commit f3f7b29

Please sign in to comment.