From ee11b70257f15277614f41cd3fc610e489b0a7fe Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 19 Sep 2024 11:06:20 +0200 Subject: [PATCH] Contexts for Zygote (#474) --- .../DifferentiationInterfaceZygoteExt.jl | 64 +++++++++++++------ .../test/Back/Zygote/test.jl | 7 ++ 2 files changed, 50 insertions(+), 21 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl index 88de03831..7027c5ca2 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl @@ -3,13 +3,15 @@ module DifferentiationInterfaceZygoteExt using ADTypes: AutoForwardDiff, AutoZygote import DifferentiationInterface as DI using DifferentiationInterface: + Context, HVPExtras, NoGradientExtras, NoHessianExtras, NoJacobianExtras, NoPullbackExtras, PullbackExtras, - Tangents + Tangents, + unwrap using ForwardDiff: ForwardDiff using Zygote: ZygoteRuleConfig, gradient, hessian, jacobian, pullback, withgradient, withjacobian @@ -25,63 +27,83 @@ struct ZygotePullbackExtrasSamePoint{Y,PB} <: PullbackExtras pb::PB end -DI.prepare_pullback(f, ::AutoZygote, x, ty::Tangents) = NoPullbackExtras() +function DI.prepare_pullback(f, ::AutoZygote, x, ty::Tangents, contexts::Vararg{Context}) + return NoPullbackExtras() +end function DI.prepare_pullback_same_point( - f, ::NoPullbackExtras, ::AutoZygote, x, ty::Tangents + f, ::NoPullbackExtras, ::AutoZygote, x, ty::Tangents, contexts::Vararg{Context} ) - y, pb = pullback(f, x) + y, pb = pullback(f, x, map(unwrap, contexts)...) return ZygotePullbackExtrasSamePoint(y, pb) end -function DI.value_and_pullback(f, ::NoPullbackExtras, ::AutoZygote, x, ty::Tangents) - y, pb = pullback(f, x) +function DI.value_and_pullback( + f, ::NoPullbackExtras, ::AutoZygote, x, ty::Tangents, contexts::Vararg{Context} +) + y, pb = pullback(f, x, map(unwrap, contexts)...) tx = map(ty) do dy - only(pb(dy)) + first(pb(dy)) end return y, tx end function DI.value_and_pullback( - f, extras::ZygotePullbackExtrasSamePoint, ::AutoZygote, x, ty::Tangents + f, + extras::ZygotePullbackExtrasSamePoint, + ::AutoZygote, + x, + ty::Tangents, + contexts::Vararg{Context}, ) @compat (; y, pb) = extras tx = map(ty) do dy - only(pb(dy)) + first(pb(dy)) end return copy(y), tx end function DI.pullback( - f, extras::ZygotePullbackExtrasSamePoint, ::AutoZygote, x, ty::Tangents + f, + extras::ZygotePullbackExtrasSamePoint, + ::AutoZygote, + x, + ty::Tangents, + contexts::Vararg{Context}, ) @compat (; pb) = extras tx = map(ty) do dy - only(pb(dy)) + first(pb(dy)) end return tx end ## Gradient -DI.prepare_gradient(f, ::AutoZygote, x) = NoGradientExtras() +DI.prepare_gradient(f, ::AutoZygote, x, contexts::Vararg{Context}) = NoGradientExtras() -function DI.value_and_gradient(f, ::NoGradientExtras, ::AutoZygote, x) - @compat (; val, grad) = withgradient(f, x) - return val, only(grad) +function DI.value_and_gradient( + f, ::NoGradientExtras, ::AutoZygote, x, contexts::Vararg{Context} +) + @compat (; val, grad) = withgradient(f, x, map(unwrap, contexts)...) + return val, first(grad) end -function DI.gradient(f, ::NoGradientExtras, ::AutoZygote, x) - return only(gradient(f, x)) +function DI.gradient(f, ::NoGradientExtras, ::AutoZygote, x, contexts::Vararg{Context}) + return first(gradient(f, x, map(unwrap, contexts)...)) end -function DI.value_and_gradient!(f, grad, extras::NoGradientExtras, backend::AutoZygote, x) - y, new_grad = DI.value_and_gradient(f, extras, backend, x) +function DI.value_and_gradient!( + f, grad, extras::NoGradientExtras, backend::AutoZygote, x, contexts::Vararg{Context} +) + y, new_grad = DI.value_and_gradient(f, extras, backend, x, contexts...) return y, copyto!(grad, new_grad) end -function DI.gradient!(f, grad, extras::NoGradientExtras, backend::AutoZygote, x) - return copyto!(grad, DI.gradient(f, extras, backend, x)) +function DI.gradient!( + f, grad, extras::NoGradientExtras, backend::AutoZygote, x, contexts::Vararg{Context} +) + return copyto!(grad, DI.gradient(f, extras, backend, x, contexts...)) end ## Jacobian diff --git a/DifferentiationInterface/test/Back/Zygote/test.jl b/DifferentiationInterface/test/Back/Zygote/test.jl index 13b38500f..f75cd5ba7 100644 --- a/DifferentiationInterface/test/Back/Zygote/test.jl +++ b/DifferentiationInterface/test/Back/Zygote/test.jl @@ -30,6 +30,13 @@ end test_differentiation(AutoZygote(); excluded=[:second_derivative], logging=LOGGING); +test_differentiation( + AutoZygote(), + default_scenarios(; include_normal=false, include_constantified=true); + second_order=false, + logging=LOGGING, +); + if VERSION >= v"1.10" test_differentiation( AutoZygote(),