From f3f7b29427bc7ad0bdc75593740ecfe332e9b799 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sat, 4 Jan 2025 23:24:52 +0100 Subject: [PATCH] fix: check nothing output for Zygote (#667) --- .../DifferentiationInterfaceZygoteExt.jl | 41 ++++++++++++-- .../test/Back/Zygote/test.jl | 53 +++++++++++-------- 2 files changed, 68 insertions(+), 26 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl index 0f681c9f9..d86f2ec89 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl @@ -6,6 +6,24 @@ using ForwardDiff: ForwardDiff using Zygote: ZygoteRuleConfig, gradient, hessian, jacobian, pullback, withgradient, withjacobian +struct ZygoteNothingError <: Exception + f + x + contexts +end + +function Base.showerror(io::IO, e::ZygoteNothingError) + (; f, x, contexts) = e + sig = (typeof(x), map(typeof ∘ DI.unwrap, contexts)...) + return print( + io, + "Zygote failed to differentiate function `$f` with argument types `$sig` (the pullback returned `nothing`).", + ) +end + +check_nothing(::Nothing, f, x, contexts) = throw(ZygoteNothingError(f, x, contexts)) +check_nothing(::Any, f, x, contexts) = nothing + DI.check_available(::AutoZygote) = true DI.inplace_support(::AutoZygote) = DI.InPlaceNotSupported() @@ -46,6 +64,7 @@ function DI.value_and_pullback( tx = map(ty) do dy first(pb(dy)) end + check_nothing(first(tx), f, x, contexts) return y, tx end @@ -61,6 +80,7 @@ function DI.value_and_pullback( tx = map(ty) do dy first(pb(dy)) end + check_nothing(first(tx), f, x, contexts) return copy(y), tx end @@ -76,6 +96,7 @@ function DI.pullback( tx = map(ty) do dy first(pb(dy)) end + check_nothing(first(tx), f, x, contexts) return tx end @@ -95,6 +116,7 @@ function DI.value_and_gradient( contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} (; val, grad) = withgradient(f, x, map(DI.unwrap, contexts)...) + check_nothing(first(grad), f, x, contexts) return val, first(grad) end @@ -105,7 +127,9 @@ function DI.gradient( x, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} - return first(gradient(f, x, map(DI.unwrap, contexts)...)) + grad = gradient(f, x, map(DI.unwrap, contexts)...) + check_nothing(first(grad), f, x, contexts) + return first(grad) end function DI.value_and_gradient!( @@ -146,8 +170,11 @@ function DI.value_and_jacobian( x, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} - return f(x, map(DI.unwrap, contexts)...), - first(jacobian(f, x, map(DI.unwrap, contexts)...)) # https://github.com/FluxML/Zygote.jl/issues/1506 + y = f(x, map(DI.unwrap, contexts)...) + # https://github.com/FluxML/Zygote.jl/issues/1506 + jac = jacobian(f, x, map(DI.unwrap, contexts)...) + check_nothing(first(jac), f, x, contexts) + return y, first(jac) end function DI.jacobian( @@ -157,7 +184,9 @@ function DI.jacobian( x, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} - return first(jacobian(f, x, map(DI.unwrap, contexts)...)) + jac = jacobian(f, x, map(DI.unwrap, contexts)...) + check_nothing(first(jac), f, x, contexts) + return first(jac) end function DI.value_and_jacobian!( @@ -266,7 +295,9 @@ function DI.hessian( contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {C} fc = DI.with_contexts(f, contexts...) - return hessian(fc, x) + hess = hessian(fc, x) + check_nothing(hess, f, x, contexts) + return hess end function DI.hessian!( diff --git a/DifferentiationInterface/test/Back/Zygote/test.jl b/DifferentiationInterface/test/Back/Zygote/test.jl index 9204673d9..e83e3c257 100644 --- a/DifferentiationInterface/test/Back/Zygote/test.jl +++ b/DifferentiationInterface/test/Back/Zygote/test.jl @@ -24,27 +24,38 @@ end ## Dense -test_differentiation( - backends, - default_scenarios(; include_constantified=true); - excluded=[:second_derivative], - logging=LOGGING, -); - -test_differentiation(second_order_backends; logging=LOGGING); - -test_differentiation( - backends[1], - vcat(component_scenarios(), gpu_scenarios()); - excluded=SECOND_ORDER, - logging=LOGGING, -) +@testset "Dense" begin + test_differentiation( + backends, + default_scenarios(; include_constantified=true); + excluded=[:second_derivative], + logging=LOGGING, + ) + + test_differentiation(second_order_backends; logging=LOGGING) + + test_differentiation( + backends[1], + vcat(component_scenarios(), gpu_scenarios()); + excluded=SECOND_ORDER, + logging=LOGGING, + ) +end ## Sparse -test_differentiation( - MyAutoSparse.(vcat(backends, second_order_backends)), - sparse_scenarios(; band_sizes=0:-1); - sparsity=true, - logging=LOGGING, -) +@testset "Sparse" begin + test_differentiation( + MyAutoSparse.(vcat(backends, second_order_backends)), + sparse_scenarios(; band_sizes=0:-1); + sparsity=true, + logging=LOGGING, + ) +end + +## Errors + +@testset "Errors" begin + safe_log(x) = x > zero(x) ? log(x) : convert(typeof(x), NaN) + @test_throws "Zygote failed to differentiate" derivative(safe_log, AutoZygote(), 0.0) +end