Skip to content

Commit

Permalink
fix: unthunk ChainRules pullback outputs
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Jan 6, 2025
1 parent 6604be2 commit 9f43dc1
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ using ChainRulesCore:
NoTangent,
RuleConfig,
frule_via_ad,
rrule_via_ad
rrule_via_ad,
unthunk
import DifferentiationInterface as DI

ruleconfig(backend::AutoChainRules) = backend.ruleconfig
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ function DI.value_and_pullback(
rc = ruleconfig(backend)
y, pb = rrule_via_ad(rc, f, x, map(DI.unwrap, contexts)...)
tx = map(ty) do dy
pb(dy)[2]
unthunk(pb(dy)[2])
end
return y, tx
end
Expand All @@ -54,7 +54,7 @@ function DI.value_and_pullback(
) where {C}
(; y, pb) = prep
tx = map(ty) do dy
pb(dy)[2]
unthunk(pb(dy)[2])
end
return copy(y), tx
end
Expand All @@ -69,7 +69,7 @@ function DI.pullback(
) where {C}
(; pb) = prep
tx = map(ty) do dy
pb(dy)[2]
unthunk(pb(dy)[2])
end
return tx
end

0 comments on commit 9f43dc1

Please sign in to comment.