Skip to content

Commit

Permalink
fix(batchedad): restructure how rrules are defined
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Aug 17, 2024
1 parent 3c88ef8 commit 3e77701
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 43 deletions.
2 changes: 1 addition & 1 deletion ext/LuxReverseDiffExt/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@
x::TrackedArray, ps::TrackedArray, ::CPUDevice)

# Nested AD
@grad_from_chainrules Lux.AutoDiffInternalImpl.batched_jacobian(
@grad_from_chainrules Lux.AutoDiffInternalImpl.batched_jacobian_internal(
f, backend::AbstractADType, x::TrackedArray)
2 changes: 1 addition & 1 deletion ext/LuxTrackerExt/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,5 @@ for T1 in (:TrackedArray, :AbstractArray), T2 in (:TrackedArray, :AbstractArray)
end

# Nested AD
@grad_from_chainrules Lux.AutoDiffInternalImpl.batched_jacobian(
@grad_from_chainrules Lux.AutoDiffInternalImpl.batched_jacobian_internal(
f, backend::AbstractADType, x::TrackedArray)
6 changes: 6 additions & 0 deletions src/autodiff/autodiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,13 @@ function vector_jacobian_product_impl end
function jacobian_vector_product end
function jacobian_vector_product_impl end

## Call Structure for `batched_jacobian`
## 1. Users call `batched_jacobian(f, ad, x)`.
## 2. This calls either `batched_jacobian_internal(f, ad, x)` or
##    `batched_jacobian_internal(f, ad, x, y)` (if `f` is a `ComposedFunction` of the
##    correct form, in which case it is first rewritten into a function `f̂` and an
##    extra argument `y`).
## 3. The 4-argument form closes over `y` and calls the 3-argument form, which in turn
##    forwards to `batched_jacobian_impl`, where the backend-specific code lives.
## We define the rrules on `batched_jacobian_internal`, not on the user-facing
## `batched_jacobian`.
function batched_jacobian end
function batched_jacobian_internal end
function batched_jacobian_impl end

#! format: off
Expand Down
70 changes: 39 additions & 31 deletions src/autodiff/batched_autodiff.jl
Original file line number Diff line number Diff line change
@@ -1,48 +1,27 @@
function batched_jacobian(f::F, backend::AbstractADType, x::AbstractArray) where {F}
return batched_jacobian_impl(f, backend, x)
return batched_jacobian_internal(f, backend, x)
end

for fType in AD_CONVERTIBLE_FUNCTIONS
@eval function batched_jacobian(f::$(fType), backend::AbstractADType, x::AbstractArray)
f̂, y = rewrite_autodiff_call(f)
return batched_jacobian(f̂, backend, x, y)
return batched_jacobian_internal(f̂, backend, x, y)
end
end

function batched_jacobian(f::F, backend::AbstractADType, x::AbstractArray, y) where {F}
return batched_jacobian_impl(Base.Fix2(f, y), backend, x)
function batched_jacobian_internal(
f::F, backend::AbstractADType, x::AbstractArray, y) where {F}
return batched_jacobian_internal(Base.Fix2(f, y), backend, x)
end

# These are useful to extend Nested AD for non-chain rules backends
function CRC.rrule(::typeof(batched_jacobian), f::F,
backend::AbstractADType, x::AbstractArray) where {F}
return CRC.rrule_via_ad(rule_config(Val(:Zygote)), batched_jacobian, f, backend, x)
end

function CRC.rrule(::typeof(batched_jacobian), f::F,
function CRC.rrule(::typeof(batched_jacobian_internal), f::F,
backend::AbstractADType, x::AbstractArray, y) where {F}
return CRC.rrule_via_ad(rule_config(Val(:Zygote)), batched_jacobian, f, backend, x, y)
return CRC.rrule_via_ad(
rule_config(Val(:Zygote)), batched_jacobian_internal, f, backend, x, y)
end

# For simplicity we will reuse the same rrule as below
function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(batched_jacobian),
f::F, backend::AbstractADType, x::AbstractArray) where {F}
f̂ = let f = f
(x, _) -> f(x)
end

res, ∇batched_jacobian_full = CRC.rrule_via_ad(
cfg, batched_jacobian, f̂, backend, x, nothing)
∇batched_jacobian = let ∇batched_jacobian_full = ∇batched_jacobian_full
Δ -> begin
_, _, _, ∂x, _ = ∇batched_jacobian_full(CRC.unthunk(Δ))
return NoTangent(), NoTangent(), NoTangent(), ∂x
end
end
return res, ∇batched_jacobian
end

function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(batched_jacobian),
function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(batched_jacobian_internal),
f::F, backend::AbstractADType, x::AbstractArray, y) where {F}
grad_fn = let cfg = cfg
(f̂, x, args...) -> begin
Expand All @@ -52,7 +31,7 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(batched_jacobian)
end

jac_fn = let backend = backend
(f̂, x̃) -> batched_jacobian_impl(f̂, backend, x̃)
(f̂, x̃) -> batched_jacobian_internal(f̂, backend, x̃)
end

res, ∇autodiff_jacobian = CRC.rrule_via_ad(
Expand All @@ -66,6 +45,35 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(batched_jacobian)
return res, ∇batched_jacobian
end

# Config-free rrule for the 3-argument `batched_jacobian_internal`. No `RuleConfig` is
# available on this path, so one is built with `rule_config(Val(:Zygote))` and the call
# is handed to `rrule_via_ad` with it.
# NOTE(review): presumably this is the entry point used by non-ChainRules backends
# (e.g. via `@grad_from_chainrules`) -- confirm against the extension packages.
function CRC.rrule(::typeof(batched_jacobian_internal), f::F,
        backend::AbstractADType, x::AbstractArray) where {F}
    zygote_cfg = rule_config(Val(:Zygote))
    return CRC.rrule_via_ad(zygote_cfg, batched_jacobian_internal, f, backend, x)
end

# Config-aware rrule for the 3-argument `batched_jacobian_internal(f, backend, x)`.
#
# Instead of carrying its own differentiation logic, this wraps `f` in a two-argument
# closure `f̂(x, _) = f(x)` and differentiates the 4-argument call
# `batched_jacobian_internal(f̂, backend, x, nothing)` through `rrule_via_ad`.
#
# Returns `(res, pullback)`. The pullback takes the output cotangent `Δ` and returns four
# tangents: `NoTangent()` for the function, for `f` and for `backend`, and `∂x` for `x`.
function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(batched_jacobian_internal),
        f::F, backend::AbstractADType, x::AbstractArray) where {F}
    # Two-argument adapter around `f`; the second argument is ignored. The `let` rebinding
    # gives the closure its own local binding of `f`.
    f̂ = let f = f
        (x, _) -> f(x)
    end

    # `nothing` fills the extra `y` slot of the 4-argument form.
    res, ∇batched_jacobian_full = CRC.rrule_via_ad(
        cfg, batched_jacobian_internal, f̂, backend, x, nothing)
    ∇batched_jacobian = let ∇batched_jacobian_full = ∇batched_jacobian_full
        Δ -> begin
            # The 4-argument pullback yields five tangents; only the one for `x` is kept,
            # and the trailing tangent for the placeholder `y` is dropped.
            _, _, _, ∂x, _ = ∇batched_jacobian_full(CRC.unthunk(Δ))
            return NoTangent(), NoTangent(), NoTangent(), ∂x
        end
    end
    return res, ∇batched_jacobian
end

# Intermediate layer between the user-facing `batched_jacobian` and the backend-specific
# `batched_jacobian_impl`. It only forwards its arguments; it exists so that the rrules can
# be attached here without introducing method ambiguities.
batched_jacobian_internal(f::F, backend::AbstractADType, x::AbstractArray) where {F} =
    batched_jacobian_impl(f, backend, x)

# ForwardDiff.jl Implementation
function batched_jacobian_impl(
f::F, backend::AutoForwardDiff{CK}, x::AbstractArray) where {F, CK}
Expand Down
17 changes: 7 additions & 10 deletions test/autodiff/batched_autodiff_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ end
using ComponentArrays, ForwardDiff, Zygote

rng = StableRNG(12345)
cdev = cpu_device()

@testset "$mode" for (mode, aType, dev, ongpu) in MODES
models = (
Expand Down Expand Up @@ -114,33 +113,30 @@ end
loss_function_simple, model, X, ps, st)

@test ∂x_batched ≈ ∂x_simple atol=1.0e-3 rtol=1.0e-3

∂ps_batched = ComponentArray(∂ps_batched |> cdev)
∂ps_simple = ComponentArray(∂ps_simple |> cdev)
@test ∂ps_batched ≈ ∂ps_simple atol=1.0e-3 rtol=1.0e-3
@test check_approx(∂ps_batched, ∂ps_simple; atol=1.0e-3, rtol=1.0e-3)

ps = ps |> cpu_device() |> ComponentArray |> dev

_, ∂x_batched2, ∂ps_batched2, _ = Zygote.gradient(
loss_function_batched, model, X, ps, st)

@test ∂x_batched2 ≈ ∂x_batched atol=1.0e-3 rtol=1.0e-3

∂ps_batched2 = ComponentArray(∂ps_batched2 |> cdev)
@test ∂ps_batched2 ≈ ∂ps_batched atol=1.0e-3 rtol=1.0e-3
@test check_approx(∂ps_batched2, ∂ps_batched; atol=1.0e-3, rtol=1.0e-3)
end
end
end

@testitem "Nested AD: Batched Jacobian Single Input" setup=[SharedTestSetup] tags=[:autodiff] begin
using ForwardDiff, Zygote, Tracker, ReverseDiff
using Zygote, Tracker, ReverseDiff

rng = StableRNG(12345)
sq_fn(x) = x .^ 2

sumabs2_fd(x) = sum(abs2, batched_jacobian(sq_fn, AutoForwardDiff(), x))
sumabs2_zyg(x) = sum(abs2, batched_jacobian(sq_fn, AutoZygote(), x))

true_gradient(x) = 8 .* x

@testset "$mode" for (mode, aType, dev, ongpu) in MODES
x = rand(rng, Float32, 4, 2) |> aType

Expand All @@ -150,7 +146,8 @@ end
∂x1_tr = only(Tracker.gradient(sumabs2_zyg, x))
∂x2_zyg = only(Zygote.gradient(sumabs2_fd, x))
∂x2_tr = only(Tracker.gradient(sumabs2_fd, x))
∂x_gt = ForwardDiff.gradient(sumabs2_fd, x)

∂x_gt = true_gradient(x)

@test ∂x1_zyg ≈ ∂x_gt atol=1.0e-3 rtol=1.0e-3
@test ∂x1_tr ≈ ∂x_gt atol=1.0e-3 rtol=1.0e-3
Expand Down

2 comments on commit 3e77701

@avik-pal
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/113339

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.5.65 -m "<description of version>" 3e7770124a0fcdae040d6297a26825c1336c2ddf
git push origin v0.5.65

Please sign in to comment.