Skip to content

Commit

Permalink
Revamp batch mode for pushforward, pullback and hvp (#412)
Browse files Browse the repository at this point in the history
* Start implementing Tangents

* Finish macros

* Toggle workflows

* Adjust operators

* Update every backend except Enzyme

* Update ForwardDiff

* Update Enzyme

* Fix typos

* Untoggle workflows

* Remaining Batch' s

* Typos

* Typo

* Typo

* Typo

* Typo

* Typo

* Dispatch

* Typo

* Typo

* Solve ambiguity

* Improve dispatch

* Start offloading Tangents{B} to backend extensions

* Continue offloading

* Pullback

* Fix FD

* Add batched Enzyme

* Enzyme

* Typo

* Debug

* Typos

* Typos

* Enzyme overloads

* Typing

* Typo

* Typo Enzyme

* Fix Enzyme

* Assertion not tangent

* Fix FromPrimitive

* Brouillon

* Typos

* FillArrays

* Fix JLArrays and ReverseDiff

* Typo

* Fix ReverseDiff

* Fix dot

* Fix

* No mydot

* Add logging

* Fix zero backends

* No type stability

* Comment out reverse Enzyme over ForwardDiff

* Reactivate logging

* Fix HVP

* Log Enzyme

* Show backend

* Move Tangents constructor

* Refactor pushforward and pullback

* Debug

* Fix

* Fixes

* Fix

* Debug Enzyme

* Extras

* Tapir working

* Fix

* Typos

* Fixs CRC
  • Loading branch information
gdalle authored Aug 29, 2024
1 parent 56fc186 commit 7c60378
Show file tree
Hide file tree
Showing 60 changed files with 1,505 additions and 1,693 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ using ChainRulesCore:
using Compat
import DifferentiationInterface as DI
using DifferentiationInterface:
DifferentiateWith, NoPullbackExtras, NoPushforwardExtras, PullbackExtras
DifferentiateWith, NoPullbackExtras, NoPushforwardExtras, PullbackExtras, Tangents

ruleconfig(backend::AutoChainRules) = backend.ruleconfig

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,34 @@ struct ChainRulesPullbackExtrasSamePoint{Y,PB} <: PullbackExtras
pb::PB
end

DI.prepare_pullback(f, ::AutoReverseChainRules, x, dy) = NoPullbackExtras()
DI.prepare_pullback(f, ::AutoReverseChainRules, x, ty::Tangents) = NoPullbackExtras()

function DI.prepare_pullback_same_point(
f, backend::AutoReverseChainRules, x, dy, ::PullbackExtras=NoPullbackExtras()
f, backend::AutoReverseChainRules, x, ty::Tangents, ::PullbackExtras=NoPullbackExtras()
)
rc = ruleconfig(backend)
y, pb = rrule_via_ad(rc, f, x)
return ChainRulesPullbackExtrasSamePoint(y, pb)
end

function DI.value_and_pullback(f, backend::AutoReverseChainRules, x, dy, ::NoPullbackExtras)
function DI.value_and_pullback(
f, backend::AutoReverseChainRules, x, ty::Tangents, ::NoPullbackExtras
)
rc = ruleconfig(backend)
y, pb = rrule_via_ad(rc, f, x)
return y, last(pb(dy))
return y, Tangents(last.(pb.(ty.d)))
end

function DI.value_and_pullback(
f, ::AutoReverseChainRules, x, dy, extras::ChainRulesPullbackExtrasSamePoint
f, ::AutoReverseChainRules, x, ty::Tangents, extras::ChainRulesPullbackExtrasSamePoint
)
@compat (; y, pb) = extras
return copy(y), last(pb(dy))
return copy(y), Tangents(last.(pb.(ty.d)))
end

function DI.pullback(
f, ::AutoReverseChainRules, x, dy, extras::ChainRulesPullbackExtrasSamePoint
f, ::AutoReverseChainRules, x, ty::Tangents, extras::ChainRulesPullbackExtrasSamePoint
)
@compat (; pb) = extras
return last(pb(dy))
return Tangents(last.(pb.(ty.d)))
end
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module DifferentiationInterfaceDiffractorExt

using ADTypes: ADTypes, AutoDiffractor
import DifferentiationInterface as DI
using DifferentiationInterface: NoPushforwardExtras
using DifferentiationInterface: NoPushforwardExtras, Tangents
using Diffractor: DiffractorRuleConfig, TaylorTangentIndex, ZeroBundle, bundle, ∂☆

DI.check_available(::AutoDiffractor) = true
Expand All @@ -11,19 +11,21 @@ DI.pullback_performance(::AutoDiffractor) = DI.PullbackSlow()

## Pushforward

DI.prepare_pushforward(f, ::AutoDiffractor, x, dx) = NoPushforwardExtras()
DI.prepare_pushforward(f, ::AutoDiffractor, x, tx::Tangents) = NoPushforwardExtras()

function DI.pushforward(f, ::AutoDiffractor, x, dx, ::NoPushforwardExtras)
# code copied from Diffractor.jl
z = ∂☆{1}()(ZeroBundle{1}(f), bundle(x, dx))
dy = z[TaylorTangentIndex(1)]
return dy
function DI.pushforward(f, ::AutoDiffractor, x, tx::Tangents, ::NoPushforwardExtras)
dys = map(tx.d) do dx
# code copied from Diffractor.jl
z = ∂☆{1}()(ZeroBundle{1}(f), bundle(x, dx))
dy = z[TaylorTangentIndex(1)]
end
return Tangents(dys)
end

function DI.value_and_pushforward(
f, backend::AutoDiffractor, x, dx, extras::NoPushforwardExtras
f, backend::AutoDiffractor, x, tx::Tangents, extras::NoPushforwardExtras
)
return f(x), DI.pushforward(f, backend, x, dx, extras)
return f(x), DI.pushforward(f, backend, x, tx, extras)
end

end
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,16 @@ using DifferentiationInterface:
DerivativeExtras,
GradientExtras,
JacobianExtras,
HVPExtras,
PullbackExtras,
PushforwardExtras,
NoDerivativeExtras,
NoGradientExtras,
NoJacobianExtras,
NoPullbackExtras,
NoPushforwardExtras,
Tangents,
SingleTangent,
pick_batchsize
using DocStringExtensions
using Enzyme:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,33 @@
## Pushforward

function DI.prepare_pushforward(f, ::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}}, x, dx)
function DI.prepare_pushforward(
f, ::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}}, x, tx::Tangents
)
return NoPushforwardExtras()
end

function DI.value_and_pushforward(
f, backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}}, x, dx, ::NoPushforwardExtras
f,
backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::Tangents,
extras::NoPushforwardExtras,
)
dys = map(tx.d) do dx
DI.pushforward(f, backend, x, dx, extras)
end
y = f(x)
return y, Tangents(dys)
end

function DI.value_and_pushforward(
f,
backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::Tangents{1},
::NoPushforwardExtras,
)
dx = only(tx)
f_and_df = get_f_and_df(f, backend)
dx_sametype = convert(typeof(x), dx)
x_and_dx = Duplicated(x, dx_sametype)
Expand All @@ -15,12 +36,17 @@ function DI.value_and_pushforward(
else
autodiff(forward_mode(backend), f_and_df, Duplicated, x_and_dx)
end
return y, new_dy
return y, SingleTangent(new_dy)
end

function DI.pushforward(
f, backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}}, x, dx, ::NoPushforwardExtras
f,
backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::Tangents{1},
::NoPushforwardExtras,
)
dx = only(tx)
f_and_df = get_f_and_df(f, backend)
dx_sametype = convert(typeof(x), dx)
x_and_dx = Duplicated(x, dx_sametype)
Expand All @@ -29,32 +55,32 @@ function DI.pushforward(
else
only(autodiff(forward_mode(backend), f_and_df, DuplicatedNoNeed, x_and_dx))
end
return new_dy
return SingleTangent(new_dy)
end

function DI.value_and_pushforward!(
f,
dy,
ty::Tangents,
backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
dx,
tx::Tangents,
extras::NoPushforwardExtras,
)
# dy cannot be passed anyway
y, new_dy = DI.value_and_pushforward(f, backend, x, dx, extras)
return y, copyto!(dy, new_dy)
y, new_ty = DI.value_and_pushforward(f, backend, x, tx, extras)
return y, copyto!(ty, new_ty)
end

function DI.pushforward!(
f,
dy,
ty::Tangents,
backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
dx,
tx::Tangents,
extras::NoPushforwardExtras,
)
# dy cannot be passed anyway
return copyto!(dy, DI.pushforward(f, backend, x, dx, extras))
return copyto!(ty, DI.pushforward(f, backend, x, tx, extras))
end

## Gradient
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
## Pushforward

function DI.prepare_pushforward(f!, y, ::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}}, x, dx)
function DI.prepare_pushforward(
f!, y, ::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}}, x, tx::Tangents
)
return NoPushforwardExtras()
end

Expand All @@ -9,9 +11,25 @@ function DI.value_and_pushforward(
y,
backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
dx,
tx::Tangents,
extras::NoPushforwardExtras,
)
dys = map(tx.d) do dx
DI.pushforward(f!, y, backend, x, dx, extras)
end
f!(y, x)
return y, Tangents(dys)
end

function DI.value_and_pushforward(
f!,
y,
backend::AnyAutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::Tangents{1},
::NoPushforwardExtras,
)
dx = only(tx)
f!_and_df! = get_f_and_df(f!, backend)
dx_sametype = convert(typeof(x), dx)
dy_sametype = make_zero(y)
Expand All @@ -22,5 +40,5 @@ function DI.value_and_pushforward(
else
autodiff(forward_mode(backend), f!_and_df!, Const, y_and_dy, x_and_dx)
end
return y, dy_sametype
return y, SingleTangent(dy_sametype)
end
Loading

0 comments on commit 7c60378

Please sign in to comment.