Skip to content

Commit

Permalink
Mutation for Enzyme and ReverseDiff
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Mar 15, 2024
1 parent be4f4a8 commit ea2731e
Show file tree
Hide file tree
Showing 34 changed files with 491 additions and 274 deletions.
25 changes: 13 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,24 @@ An interface to various automatic differentiation backends in Julia.

This package provides a backend-agnostic syntax to differentiate functions of the following types:

- Allocating: `f(x) = y` where `x` and `y` can be real numbers or abstract arrays
- Mutating: `f!(y, x)` where `y` is an abstract array and `x` can be a real number or an abstract array
- **Allocating**: `f(x) = y` where `x` and `y` can be real numbers or abstract arrays
- **Mutating**: `f!(y, x) = nothing` where `y` is an abstract array and `x` can be a real number or an abstract array

## Compatibility

We support some of the backends defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl):

| Backend | Type |
| :------------------------------------------------------------------------------ | :--------------------------------------------------------- |
| [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl) | `AutoChainRules(ruleconfig)` |
| [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl) | `AutoDiffractor()` |
| [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) | `AutoEnzyme(Val(:forward))` or `AutoEnzyme(Val(:reverse))` |
| [FiniteDiff.jl](https://github.com/JuliaDiff/FiniteDiff.jl) | `AutoFiniteDiff()` |
| [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) | `AutoForwardDiff()` |
| [PolyesterForwardDiff.jl](https://github.com/JuliaDiff/PolyesterForwardDiff.jl) | `AutoPolyesterForwardDiff(; chunksize=C)` |
| [ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl) | `AutoReverseDiff()` |
| [Zygote.jl](https://github.com/FluxML/Zygote.jl) | `AutoZygote()` |
| Backend | Object | Allocating | Mutating |
| :------------------------------------------------------------------------------ | :-------------------------------------- | :--------- | :------- |
| [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl) | `AutoChainRules(ruleconfig)` |||
| [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl) | `AutoDiffractor()` |||
| [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) (forward) | `AutoEnzyme(Enzyme.Forward)` |||
| [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) (reverse) | `AutoEnzyme(Enzyme.Reverse)` |||
| [FiniteDiff.jl](https://github.com/JuliaDiff/FiniteDiff.jl) | `AutoFiniteDiff()` || soon |
| [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) | `AutoForwardDiff()` |||
| [PolyesterForwardDiff.jl](https://github.com/JuliaDiff/PolyesterForwardDiff.jl) | `AutoPolyesterForwardDiff(; chunksize)` |||
| [ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl) | `AutoReverseDiff()` |||
| [Zygote.jl](https://github.com/FluxML/Zygote.jl) | `AutoZygote()` |||

## Example

Expand Down
3 changes: 3 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@ DocumenterMermaid = "a078cd44-4d9c-4618-b545-3ab9d77f9177"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
Expand Down
5 changes: 5 additions & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ using DifferentiationInterface.DifferentiationTest
import DifferentiationInterface as DI
using Documenter
using DocumenterMermaid
using JET
using Random
using Test

using ADTypes
using Diffractor: Diffractor
Expand All @@ -23,6 +26,7 @@ PolyesterForwardDiffExt = get_extension(
DI, :DifferentiationInterfacePolyesterForwardDiffExt
)
ReverseDiffExt = get_extension(DI, :DifferentiationInterfaceReverseDiffExt)
TestExt = get_extension(DI, :DifferentiationInterfaceTestExt)
ZygoteExt = get_extension(DI, :DifferentiationInterfaceZygoteExt)

DocMeta.setdocmeta!(
Expand Down Expand Up @@ -58,6 +62,7 @@ makedocs(;
ForwardDiffExt,
PolyesterForwardDiffExt,
ReverseDiffExt,
TestExt,
ZygoteExt,
],
authors="Guillaume Dalle, Adrian Hill",
Expand Down
4 changes: 2 additions & 2 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,5 +72,5 @@ Public = false
These are not part of the public API.

```@autodocs
Modules = [DifferentiationTest]
```
Modules = [DifferentiationTest, Base.get_extension(DifferentiationInterface, :DifferentiationInterfaceTestExt)]
```
13 changes: 0 additions & 13 deletions docs/src/backends.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,6 @@ AutoReverseDiff
AutoZygote
```

## Accepted functions

| Backend | `f(x) = y` | `f!(y, x)` |
| -------------------------- | ---------- | ---------- |
| `AutoChainRules` | yes | no |
| `AutoDiffractor` | yes | no |
| `AutoEnzyme` | yes | soon |
| `AutoForwardDiff` | yes | yes |
| `AutoFiniteDiff` | yes | soon |
| `AutoPolyesterForwardDiff` | yes | soon |
| `AutoReverseDiff` | yes | soon |
| `AutoZygote` | yes | no |

## Package extensions

```@meta
Expand Down
2 changes: 1 addition & 1 deletion docs/src/getting_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ We do not make any guarantees on their implementation for each backend, or on th

## Mutating functions

In addition to allocating functions `f(x) = y`, we also support mutating functions `f!(y, x)` whenever the output is an array.
In addition to allocating functions `f(x) = y`, we also support mutating functions `f!(y, x) = nothing` whenever the output is an array (beware that it must return `nothing`).
Since they operate in-place and the primal is computed every time, only four operators are defined:

| **Operator** | **mutating with primal** |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,33 +7,48 @@ using Enzyme:
Active,
Const,
Duplicated,
DuplicatedNoNeed,
Forward,
ForwardMode,
Reverse,
ReverseWithPrimal,
ReverseMode,
autodiff,
gradient,
gradient!,
jacobian

"""
AutoEnzyme(Val(:forward))
AutoEnzyme(Val(:reverse))
AutoEnzyme(Enzyme.Forward)
AutoEnzyme(Enzyme.Reverse)
Construct a forward or reverse mode `AutoEnzyme` backend.
!!! warning
This is the mode convention chosen by DifferentiationInterface.jl, for lack of a global consensus (see [ADTypes.jl#24](https://github.com/SciML/ADTypes.jl/issues/24)).
"""
AutoEnzyme

const AutoForwardEnzyme = AutoEnzyme{<:ForwardMode}
const AutoReverseEnzyme = AutoEnzyme{<:ReverseMode}

function DI.autodiff_mode(::AutoEnzyme)
return error(
"You need to specify the Enzyme mode with `AutoEnzyme(Enzyme.Forward)` or `AutoEnzyme(Enzyme.Reverse)`",
)
end

DI.autodiff_mode(::AutoForwardEnzyme) = DI.ForwardMode()
DI.autodiff_mode(::AutoReverseEnzyme) = DI.ReverseMode()

# Enzyme's `Duplicated(x, dx)` expects both arguments to be of the same type
function DI.basisarray(::AutoEnzyme, a::AbstractArray{T}, i::CartesianIndex) where {T}
b = zero(a)
b[i] = one(T)
return b
end

include("forward.jl")
include("reverse.jl")
include("forward_allocating.jl")
include("forward_mutating.jl")

include("reverse_allocating.jl")
include("reverse_mutating.jl")

end # module
30 changes: 0 additions & 30 deletions ext/DifferentiationInterfaceEnzymeExt/forward.jl

This file was deleted.

54 changes: 54 additions & 0 deletions ext/DifferentiationInterfaceEnzymeExt/forward_allocating.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
## Primitives

function DI.value_and_pushforward!(
_dy::Real, backend::AutoForwardEnzyme, f, x, dx, extras::Nothing=nothing
)
y, new_dy = autodiff(backend.mode, f, Duplicated, Duplicated(x, dx))
return y, new_dy
end

function DI.value_and_pushforward!(
dy::AbstractArray, backend::AutoForwardEnzyme, f, x, dx, extras::Nothing=nothing
)
y, new_dy = autodiff(backend.mode, f, Duplicated, Duplicated(x, dx))
dy .= new_dy
return y, dy
end

function DI.pushforward!(
_dy::Real, backend::AutoForwardEnzyme, f, x, dx, extras::Nothing=nothing
)
new_dy = only(autodiff(backend.mode, f, DuplicatedNoNeed, Duplicated(x, dx)))
return new_dy
end

function DI.pushforward!(
dy::AbstractArray, backend::AutoForwardEnzyme, f, x, dx, extras::Nothing=nothing
)
new_dy = only(autodiff(backend.mode, f, DuplicatedNoNeed, Duplicated(x, dx)))
dy .= new_dy
return dy
end

function DI.value_and_pushforward(
backend::AutoForwardEnzyme, f, x, dx, extras::Nothing=nothing
)
y, dy = autodiff(backend.mode, f, Duplicated, Duplicated(x, dx))
return y, dy
end

function DI.pushforward(backend::AutoForwardEnzyme, f, x, dx, extras::Nothing=nothing)
dy = only(autodiff(backend.mode, f, DuplicatedNoNeed, Duplicated(x, dx)))
return dy
end

## Utilities

function DI.value_and_jacobian(
backend::AutoForwardEnzyme, f, x::AbstractArray, extras::Nothing=nothing
)
y = f(x)
jac = jacobian(backend.mode, f, x)
# see https://github.com/EnzymeAD/Enzyme.jl/issues/1332
return y, reshape(jac, length(y), length(x))
end
19 changes: 19 additions & 0 deletions ext/DifferentiationInterfaceEnzymeExt/forward_mutating.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
## Primitives

function DI.value_and_pushforward!(
y::AbstractArray,
dy::AbstractArray,
backend::AutoForwardEnzyme,
f!,
x,
dx,
extras::Nothing=nothing,
)
dx_sametype = convert(typeof(x), dx)
dy_sametype = convert(typeof(y), dy)
autodiff(
backend.mode, f!, Const, Duplicated(y, dy_sametype), Duplicated(x, dx_sametype)
)
dy .= dy_sametype
return y, dy
end
Original file line number Diff line number Diff line change
@@ -1,21 +1,18 @@
const AutoReverseEnzyme = AutoEnzyme{Val{:reverse}}
DI.autodiff_mode(::AutoReverseEnzyme) = DI.ReverseMode()

# see https://enzymead.github.io/Enzyme.jl/stable/pullbacks/

struct MakeFunctionMutating{F}
f::F
end

function (mf!::MakeFunctionMutating)(y::AbstractArray, x)
y .= mf!.f(x)
function (f!::MakeFunctionMutating)(y::AbstractArray, x)
y .= f!.f(x)
return nothing
end

## Primitives

function DI.value_and_pullback!(
_dx::Number, ::AutoReverseEnzyme, f, x::Number, dy, extras::Nothing=nothing
_dx::Number, ::AutoReverseEnzyme, f, x::Number, dy::Number, extras::Nothing=nothing
)
der, y = autodiff(ReverseWithPrimal, f, Active, Active(x))
new_dx = dy * only(der)
Expand All @@ -36,38 +33,64 @@ function DI.value_and_pullback!(
return y, dx
end

function DI.value_and_pullback!(
_dx::Number,
function DI.pullback!(
_dx::Number, ::AutoReverseEnzyme, f, x::Number, dy::Number, extras::Nothing=nothing
)
der = only(autodiff(Reverse, f, Active, Active(x)))
new_dx = dy * only(der)
return new_dx
end

function DI.pullback!(
dx::AbstractArray,
::AutoReverseEnzyme,
f,
x::AbstractArray,
dy::Number,
extras::Nothing=nothing,
)
dx .= zero(eltype(dx))
autodiff(Reverse, f, Active, Duplicated(x, dx))
dx .*= dy
return dx
end

function DI.value_and_pullback!(
dx::Number,
backend::AutoReverseEnzyme,
f,
x::Number,
dy::AbstractArray,
extras::Nothing=nothing,
)
y = f(x)
mf! = MakeFunctionMutating(f)
_, new_dx = only(autodiff(Reverse, mf!, Const, Duplicated(y, copy(dy)), Active(x)))
return y, new_dx
f! = MakeFunctionMutating(f)
return DI.value_and_pullback!(y, dx, backend, f!, x, dy, extras)
end

function DI.value_and_pullback!(
dx::AbstractArray,
::AutoReverseEnzyme,
backend::AutoReverseEnzyme,
f,
x::AbstractArray,
dy::AbstractArray,
extras::Nothing=nothing,
)
y = f(x)
dx_like_x = zero(x)
mf! = MakeFunctionMutating(f)
autodiff(Reverse, mf!, Const, Duplicated(y, copy(dy)), Duplicated(x, dx_like_x))
dx .= dx_like_x
return y, dx
f! = MakeFunctionMutating(f)
return DI.value_and_pullback!(y, dx, backend, f!, x, dy, extras)
end

## Utilities

function DI.value_and_gradient!(
grad::AbstractArray, ::AutoReverseEnzyme, f, x::AbstractArray, extras::Nothing=nothing
)
y = f(x)
gradient!(Reverse, grad, f, x)
return y, grad
end

function DI.value_and_gradient(
::AutoReverseEnzyme, f, x::AbstractArray, extras::Nothing=nothing
)
Expand All @@ -76,10 +99,14 @@ function DI.value_and_gradient(
return y, grad
end

function DI.value_and_gradient!(
function DI.gradient!(
grad::AbstractArray, ::AutoReverseEnzyme, f, x::AbstractArray, extras::Nothing=nothing
)
y = f(x)
gradient!(Reverse, grad, f, x)
return y, grad
return grad
end

function DI.gradient(::AutoReverseEnzyme, f, x::AbstractArray, extras::Nothing=nothing)
grad = gradient(Reverse, f, x)
return grad
end
Loading

0 comments on commit ea2731e

Please sign in to comment.