Skip to content

Commit

Permalink
fix: improve ForwardDiff tagging for HVP (#596)
Browse files Browse the repository at this point in the history
* Improve ForwardDiff tagging

* Remove tag unwrapping for FixTail

* Cov

* Bump DI
  • Loading branch information
gdalle authored Oct 30, 2024
1 parent 16c0194 commit 346834e
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 47 deletions.
2 changes: 1 addition & 1 deletion DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DifferentiationInterface"
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
authors = ["Guillaume Dalle", "Adrian Hill"]
version = "0.6.17"
version = "0.6.18"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ using DifferentiationInterface:
BatchSizeSettings,
Cache,
Constant,
PrepContext,
Context,
FixTail,
DerivativePrep,
DifferentiateWith,
GradientPrep,
Expand All @@ -21,6 +23,7 @@ using DifferentiationInterface:
SecondOrder,
inner,
outer,
shuffled_gradient,
unwrap,
with_contexts
import ForwardDiff.DiffResults as DR
Expand Down
Original file line number Diff line number Diff line change
@@ -1,23 +1,6 @@
struct ForwardDiffOverSomethingHVPWrapper{F}
f::F
end

"""
tag_backend_hvp(f, ::AutoForwardDiff, x)
Return a new `AutoForwardDiff` backend with a fixed tag linked to `f`, so that we know how to prepare the inner gradient of the HVP without depending on what that gradient closure looks like.
"""
tag_backend_hvp(f, backend::AutoForwardDiff, x) = backend

function tag_backend_hvp(f::F, ::AutoForwardDiff{chunksize,Nothing}, x) where {F,chunksize}
tag = ForwardDiff.Tag(ForwardDiffOverSomethingHVPWrapper(f), eltype(x))
return AutoForwardDiff{chunksize,typeof(tag)}(tag)
end

struct ForwardDiffOverSomethingHVPPrep{B<:AutoForwardDiff,G,E<:PushforwardPrep} <: HVPPrep
tagged_outer_backend::B
inner_gradient::G
outer_pushforward_prep::E
struct ForwardDiffOverSomethingHVPPrep{E1<:GradientPrep,E2<:PushforwardPrep} <: HVPPrep
inner_gradient_prep::E1
outer_pushforward_prep::E2
end

function DI.prepare_hvp(
Expand All @@ -27,65 +10,94 @@ function DI.prepare_hvp(
tx::NTuple,
contexts::Vararg{Context,C},
) where {F,C}
rewrap = Rewrap(contexts...)
tagged_outer_backend = tag_backend_hvp(f, outer(backend), x)
T = tag_type(f, tagged_outer_backend, x)
T = tag_type(shuffled_gradient, outer(backend), x)
xdual = make_dual(T, x, tx)
gradient_prep = DI.prepare_gradient(f, inner(backend), xdual, contexts...)
# TODO: get rid of closure?
function inner_gradient(x, unannotated_contexts...)
annotated_contexts = rewrap(unannotated_contexts...)
return DI.gradient(f, gradient_prep, inner(backend), x, annotated_contexts...)
end
outer_pushforward_prep = DI.prepare_pushforward(
inner_gradient, tagged_outer_backend, x, tx, contexts...
inner_gradient_prep = DI.prepare_gradient(f, inner(backend), xdual, contexts...)
rewrap = Rewrap(contexts...)
new_contexts = (
Constant(f),
PrepContext(inner_gradient_prep),
Constant(inner(backend)),
Constant(rewrap),
contexts...,
)
return ForwardDiffOverSomethingHVPPrep(
tagged_outer_backend, inner_gradient, outer_pushforward_prep
outer_pushforward_prep = DI.prepare_pushforward(
shuffled_gradient, outer(backend), x, tx, new_contexts...
)
return ForwardDiffOverSomethingHVPPrep(inner_gradient_prep, outer_pushforward_prep)
end

function DI.hvp(
f::F,
prep::ForwardDiffOverSomethingHVPPrep,
::SecondOrder{<:AutoForwardDiff},
backend::SecondOrder{<:AutoForwardDiff},
x,
tx::NTuple,
contexts::Vararg{Context,C},
) where {F,C}
(; tagged_outer_backend, inner_gradient, outer_pushforward_prep) = prep
(; inner_gradient_prep, outer_pushforward_prep) = prep
rewrap = Rewrap(contexts...)
new_contexts = (
Constant(f),
PrepContext(inner_gradient_prep),
Constant(inner(backend)),
Constant(rewrap),
contexts...,
)
return DI.pushforward(
inner_gradient, outer_pushforward_prep, tagged_outer_backend, x, tx, contexts...
shuffled_gradient, outer_pushforward_prep, outer(backend), x, tx, new_contexts...
)
end

function DI.hvp!(
f::F,
tg::NTuple,
prep::ForwardDiffOverSomethingHVPPrep,
::SecondOrder{<:AutoForwardDiff},
backend::SecondOrder{<:AutoForwardDiff},
x,
tx::NTuple,
contexts::Vararg{Context,C},
) where {F,C}
(; tagged_outer_backend, inner_gradient, outer_pushforward_prep) = prep
DI.pushforward!(
inner_gradient, tg, outer_pushforward_prep, tagged_outer_backend, x, tx, contexts...
(; inner_gradient_prep, outer_pushforward_prep) = prep
rewrap = Rewrap(contexts...)
new_contexts = (
Constant(f),
PrepContext(inner_gradient_prep),
Constant(inner(backend)),
Constant(rewrap),
contexts...,
)
return DI.pushforward!(
shuffled_gradient,
tg,
outer_pushforward_prep,
outer(backend),
x,
tx,
new_contexts...,
)
return tg
end

function DI.gradient_and_hvp(
f::F,
prep::ForwardDiffOverSomethingHVPPrep,
::SecondOrder{<:AutoForwardDiff},
backend::SecondOrder{<:AutoForwardDiff},
x,
tx::NTuple,
contexts::Vararg{Context,C},
) where {F,C}
(; tagged_outer_backend, inner_gradient, outer_pushforward_prep) = prep
(; inner_gradient_prep, outer_pushforward_prep) = prep
rewrap = Rewrap(contexts...)
new_contexts = (
Constant(f),
PrepContext(inner_gradient_prep),
Constant(inner(backend)),
Constant(rewrap),
contexts...,
)
return DI.value_and_pushforward(
inner_gradient, outer_pushforward_prep, tagged_outer_backend, x, tx, contexts...
shuffled_gradient, outer_pushforward_prep, outer(backend), x, tx, new_contexts...
)
end

Expand All @@ -94,14 +106,28 @@ function DI.gradient_and_hvp!(
grad,
tg::NTuple,
prep::ForwardDiffOverSomethingHVPPrep,
::SecondOrder{<:AutoForwardDiff},
backend::SecondOrder{<:AutoForwardDiff},
x,
tx::NTuple,
contexts::Vararg{Context,C},
) where {F,C}
(; tagged_outer_backend, inner_gradient, outer_pushforward_prep) = prep
(; inner_gradient_prep, outer_pushforward_prep) = prep
rewrap = Rewrap(contexts...)
new_contexts = (
Constant(f),
PrepContext(inner_gradient_prep),
Constant(inner(backend)),
Constant(rewrap),
contexts...,
)
new_grad, _ = DI.value_and_pushforward!(
inner_gradient, tg, outer_pushforward_prep, tagged_outer_backend, x, tx, contexts...
shuffled_gradient,
tg,
outer_pushforward_prep,
outer(backend),
x,
tx,
new_contexts...,
)
return copyto!(grad, new_grad), tg
end
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ function mypartials!(::Type{T}, ty::NTuple{B}, ydual) where {T,B}
end

_translate(::Type{T}, ::Val{B}, c::Constant) where {T,B} = unwrap(c)
_translate(::Type{T}, ::Val{B}, c::PrepContext) where {T,B} = unwrap(c)

function _translate(::Type{T}, ::Val{B}, c::Cache) where {T,B}
c0 = unwrap(c)
Expand Down
11 changes: 11 additions & 0 deletions DifferentiationInterface/src/first_order/gradient.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,14 @@ function shuffled_gradient(
) where {F,C}
return gradient(f, backend, x, rewrap(unannotated_contexts...)...)
end

function shuffled_gradient(
x,
f::F,
prep::GradientPrep,
backend::AbstractADType,
rewrap::Rewrap{C},
unannotated_contexts::Vararg{Any,C},
) where {F,C}
return gradient(f, prep, backend, x, rewrap(unannotated_contexts...)...)
end
6 changes: 6 additions & 0 deletions DifferentiationInterface/src/utils/context.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,12 @@ unwrap(c::Cache) = c.data

Base.:(==)(c1::Cache, c2::Cache) = c1.data == c2.data

struct PrepContext{T<:Prep} <: Context
data::T
end

unwrap(c::PrepContext) = c.data

struct Rewrap{C,T}
context_makers::T
function Rewrap(contexts::Vararg{Context,C}) where {C}
Expand Down

4 comments on commit 346834e

@gdalle
Copy link
Member Author

@gdalle gdalle commented on 346834e Oct 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register subdir=DifferentiationInterface

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Error while trying to register: Changing package repo URL not allowed, please submit a pull request with the URL change to the target registry and retry.

@gdalle
Copy link
Member Author

@gdalle gdalle commented on 346834e Oct 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register subdir=DifferentiationInterface

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/118361

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a DifferentiationInterface-v0.6.18 -m "<description of version>" 346834e1f95a90b21a161620e53a3cf5bee95de0
git push origin DifferentiationInterface-v0.6.18

Please sign in to comment.