Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mutation for Enzyme and ReverseDiff #43

Merged
merged 1 commit into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 13 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,24 @@ An interface to various automatic differentiation backends in Julia.

This package provides a backend-agnostic syntax to differentiate functions of the following types:

- Allocating: `f(x) = y` where `x` and `y` can be real numbers or abstract arrays
- Mutating: `f!(y, x)` where `y` is an abstract array and `x` can be a real number or an abstract array
- **Allocating**: `f(x) = y` where `x` and `y` can be real numbers or abstract arrays
- **Mutating**: `f!(y, x) = nothing` where `y` is an abstract array and `x` can be a real number or an abstract array

## Compatibility

We support some of the backends defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl):

| Backend | Type |
| :------------------------------------------------------------------------------ | :--------------------------------------------------------- |
| [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl) | `AutoChainRules(ruleconfig)` |
| [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl) | `AutoDiffractor()` |
| [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) | `AutoEnzyme(Val(:forward))` or `AutoEnzyme(Val(:reverse))` |
| [FiniteDiff.jl](https://github.com/JuliaDiff/FiniteDiff.jl) | `AutoFiniteDiff()` |
| [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) | `AutoForwardDiff()` |
| [PolyesterForwardDiff.jl](https://github.com/JuliaDiff/PolyesterForwardDiff.jl) | `AutoPolyesterForwardDiff(; chunksize=C)` |
| [ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl) | `AutoReverseDiff()` |
| [Zygote.jl](https://github.com/FluxML/Zygote.jl) | `AutoZygote()` |
| Backend | Object | Allocating | Mutating |
| :------------------------------------------------------------------------------ | :-------------------------------------- | :--------- | :------- |
| [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl) | `AutoChainRules(ruleconfig)` | ✓ | ✗ |
| [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl) | `AutoDiffractor()` | ✓ | ✗ |
| [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) (forward) | `AutoEnzyme(Enzyme.Forward)` | ✓ | ✓ |
| [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) (reverse) | `AutoEnzyme(Enzyme.Reverse)` | ✓ | ✓ |
| [FiniteDiff.jl](https://github.com/JuliaDiff/FiniteDiff.jl) | `AutoFiniteDiff()` | ✓ | soon |
| [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) | `AutoForwardDiff()` | ✓ | ✓ |
| [PolyesterForwardDiff.jl](https://github.com/JuliaDiff/PolyesterForwardDiff.jl) | `AutoPolyesterForwardDiff(; chunksize)` | ✓ | ✓ |
| [ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl) | `AutoReverseDiff()` | ✓ | ✓ |
| [Zygote.jl](https://github.com/FluxML/Zygote.jl) | `AutoZygote()` | ✓ | ✗ |

## Example

Expand Down
3 changes: 3 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@ DocumenterMermaid = "a078cd44-4d9c-4618-b545-3ab9d77f9177"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
Expand Down
5 changes: 5 additions & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ using DifferentiationInterface.DifferentiationTest
import DifferentiationInterface as DI
using Documenter
using DocumenterMermaid
using JET
using Random
using Test

using ADTypes
using Diffractor: Diffractor
Expand All @@ -23,6 +26,7 @@ PolyesterForwardDiffExt = get_extension(
DI, :DifferentiationInterfacePolyesterForwardDiffExt
)
ReverseDiffExt = get_extension(DI, :DifferentiationInterfaceReverseDiffExt)
TestExt = get_extension(DI, :DifferentiationInterfaceTestExt)
ZygoteExt = get_extension(DI, :DifferentiationInterfaceZygoteExt)

DocMeta.setdocmeta!(
Expand Down Expand Up @@ -58,6 +62,7 @@ makedocs(;
ForwardDiffExt,
PolyesterForwardDiffExt,
ReverseDiffExt,
TestExt,
ZygoteExt,
],
authors="Guillaume Dalle, Adrian Hill",
Expand Down
4 changes: 2 additions & 2 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,5 +72,5 @@ Public = false
These are not part of the public API.

```@autodocs
Modules = [DifferentiationTest]
```
Modules = [DifferentiationTest, Base.get_extension(DifferentiationInterface, :DifferentiationInterfaceTestExt)]
```
13 changes: 0 additions & 13 deletions docs/src/backends.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,6 @@ AutoReverseDiff
AutoZygote
```

## Accepted functions

| Backend | `f(x) = y` | `f!(y, x)` |
| -------------------------- | ---------- | ---------- |
| `AutoChainRules` | yes | no |
| `AutoDiffractor` | yes | no |
| `AutoEnzyme` | yes | soon |
| `AutoForwardDiff` | yes | yes |
| `AutoFiniteDiff` | yes | soon |
| `AutoPolyesterForwardDiff` | yes | soon |
| `AutoReverseDiff` | yes | soon |
| `AutoZygote` | yes | no |

## Package extensions

```@meta
Expand Down
2 changes: 1 addition & 1 deletion docs/src/getting_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ We do not make any guarantees on their implementation for each backend, or on th

## Mutating functions

In addition to allocating functions `f(x) = y`, we also support mutating functions `f!(y, x)` whenever the output is an array.
In addition to allocating functions `f(x) = y`, we also support mutating functions `f!(y, x) = nothing` whenever the output is an array (beware that it must return `nothing`).
Since they operate in-place and the primal is computed every time, only four operators are defined:

| **Operator** | **mutating with primal** |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,33 +7,48 @@ using Enzyme:
Active,
Const,
Duplicated,
DuplicatedNoNeed,
Forward,
ForwardMode,
Reverse,
ReverseWithPrimal,
ReverseMode,
autodiff,
gradient,
gradient!,
jacobian

"""
AutoEnzyme(Val(:forward))
AutoEnzyme(Val(:reverse))
AutoEnzyme(Enzyme.Forward)
AutoEnzyme(Enzyme.Reverse)

Construct a forward or reverse mode `AutoEnzyme` backend.

!!! warning
This is the mode convention chosen by DifferentiationInterface.jl, for lack of a global consensus (see [ADTypes.jl#24](https://github.com/SciML/ADTypes.jl/issues/24)).
"""
AutoEnzyme

const AutoForwardEnzyme = AutoEnzyme{<:ForwardMode}
const AutoReverseEnzyme = AutoEnzyme{<:ReverseMode}

function DI.autodiff_mode(::AutoEnzyme)
return error(
"You need to specify the Enzyme mode with `AutoEnzyme(Enzyme.Forward)` or `AutoEnzyme(Enzyme.Reverse)`",
)
end

DI.autodiff_mode(::AutoForwardEnzyme) = DI.ForwardMode()
DI.autodiff_mode(::AutoReverseEnzyme) = DI.ReverseMode()

# Enzyme's `Duplicated(x, dx)` expects both arguments to be of the same type
function DI.basisarray(::AutoEnzyme, a::AbstractArray{T}, i::CartesianIndex) where {T}
b = zero(a)
b[i] = one(T)
return b
end

include("forward.jl")
include("reverse.jl")
include("forward_allocating.jl")
include("forward_mutating.jl")

include("reverse_allocating.jl")
include("reverse_mutating.jl")

end # module
30 changes: 0 additions & 30 deletions ext/DifferentiationInterfaceEnzymeExt/forward.jl

This file was deleted.

54 changes: 54 additions & 0 deletions ext/DifferentiationInterfaceEnzymeExt/forward_allocating.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
## Primitives

function DI.value_and_pushforward!(
_dy::Real, backend::AutoForwardEnzyme, f, x, dx, extras::Nothing=nothing
)
y, new_dy = autodiff(backend.mode, f, Duplicated, Duplicated(x, dx))
return y, new_dy
end

function DI.value_and_pushforward!(
dy::AbstractArray, backend::AutoForwardEnzyme, f, x, dx, extras::Nothing=nothing
)
y, new_dy = autodiff(backend.mode, f, Duplicated, Duplicated(x, dx))
dy .= new_dy
return y, dy
end

function DI.pushforward!(
_dy::Real, backend::AutoForwardEnzyme, f, x, dx, extras::Nothing=nothing
)
new_dy = only(autodiff(backend.mode, f, DuplicatedNoNeed, Duplicated(x, dx)))
return new_dy
end

function DI.pushforward!(
dy::AbstractArray, backend::AutoForwardEnzyme, f, x, dx, extras::Nothing=nothing
)
new_dy = only(autodiff(backend.mode, f, DuplicatedNoNeed, Duplicated(x, dx)))
dy .= new_dy
return dy
end

function DI.value_and_pushforward(
backend::AutoForwardEnzyme, f, x, dx, extras::Nothing=nothing
)
y, dy = autodiff(backend.mode, f, Duplicated, Duplicated(x, dx))
return y, dy
end

function DI.pushforward(backend::AutoForwardEnzyme, f, x, dx, extras::Nothing=nothing)
dy = only(autodiff(backend.mode, f, DuplicatedNoNeed, Duplicated(x, dx)))
return dy
end

## Utilities

function DI.value_and_jacobian(
backend::AutoForwardEnzyme, f, x::AbstractArray, extras::Nothing=nothing
)
y = f(x)
jac = jacobian(backend.mode, f, x)
# see https://github.com/EnzymeAD/Enzyme.jl/issues/1332
return y, reshape(jac, length(y), length(x))
end
19 changes: 19 additions & 0 deletions ext/DifferentiationInterfaceEnzymeExt/forward_mutating.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
## Primitives

function DI.value_and_pushforward!(
y::AbstractArray,
dy::AbstractArray,
backend::AutoForwardEnzyme,
f!,
x,
dx,
extras::Nothing=nothing,
)
dx_sametype = convert(typeof(x), dx)
dy_sametype = convert(typeof(y), dy)
autodiff(
backend.mode, f!, Const, Duplicated(y, dy_sametype), Duplicated(x, dx_sametype)
)
dy .= dy_sametype
return y, dy
end
Original file line number Diff line number Diff line change
@@ -1,21 +1,18 @@
const AutoReverseEnzyme = AutoEnzyme{Val{:reverse}}
DI.autodiff_mode(::AutoReverseEnzyme) = DI.ReverseMode()

# see https://enzymead.github.io/Enzyme.jl/stable/pullbacks/

struct MakeFunctionMutating{F}
f::F
end

function (mf!::MakeFunctionMutating)(y::AbstractArray, x)
y .= mf!.f(x)
function (f!::MakeFunctionMutating)(y::AbstractArray, x)
y .= f!.f(x)
return nothing
end

## Primitives

function DI.value_and_pullback!(
_dx::Number, ::AutoReverseEnzyme, f, x::Number, dy, extras::Nothing=nothing
_dx::Number, ::AutoReverseEnzyme, f, x::Number, dy::Number, extras::Nothing=nothing
)
der, y = autodiff(ReverseWithPrimal, f, Active, Active(x))
new_dx = dy * only(der)
Expand All @@ -36,38 +33,64 @@ function DI.value_and_pullback!(
return y, dx
end

function DI.value_and_pullback!(
_dx::Number,
function DI.pullback!(
_dx::Number, ::AutoReverseEnzyme, f, x::Number, dy::Number, extras::Nothing=nothing
)
der = only(autodiff(Reverse, f, Active, Active(x)))
new_dx = dy * only(der)
return new_dx
end

function DI.pullback!(
dx::AbstractArray,
::AutoReverseEnzyme,
f,
x::AbstractArray,
dy::Number,
extras::Nothing=nothing,
)
dx .= zero(eltype(dx))
autodiff(Reverse, f, Active, Duplicated(x, dx))
dx .*= dy
return dx
end

function DI.value_and_pullback!(
dx::Number,
backend::AutoReverseEnzyme,
f,
x::Number,
dy::AbstractArray,
extras::Nothing=nothing,
)
y = f(x)
mf! = MakeFunctionMutating(f)
_, new_dx = only(autodiff(Reverse, mf!, Const, Duplicated(y, copy(dy)), Active(x)))
return y, new_dx
f! = MakeFunctionMutating(f)
return DI.value_and_pullback!(y, dx, backend, f!, x, dy, extras)
end

function DI.value_and_pullback!(
dx::AbstractArray,
::AutoReverseEnzyme,
backend::AutoReverseEnzyme,
f,
x::AbstractArray,
dy::AbstractArray,
extras::Nothing=nothing,
)
y = f(x)
dx_like_x = zero(x)
mf! = MakeFunctionMutating(f)
autodiff(Reverse, mf!, Const, Duplicated(y, copy(dy)), Duplicated(x, dx_like_x))
dx .= dx_like_x
return y, dx
f! = MakeFunctionMutating(f)
return DI.value_and_pullback!(y, dx, backend, f!, x, dy, extras)
end

## Utilities

function DI.value_and_gradient!(
grad::AbstractArray, ::AutoReverseEnzyme, f, x::AbstractArray, extras::Nothing=nothing
)
y = f(x)
gradient!(Reverse, grad, f, x)
return y, grad
end

function DI.value_and_gradient(
::AutoReverseEnzyme, f, x::AbstractArray, extras::Nothing=nothing
)
Expand All @@ -76,10 +99,14 @@ function DI.value_and_gradient(
return y, grad
end

function DI.value_and_gradient!(
function DI.gradient!(
grad::AbstractArray, ::AutoReverseEnzyme, f, x::AbstractArray, extras::Nothing=nothing
)
y = f(x)
gradient!(Reverse, grad, f, x)
return y, grad
return grad
end

function DI.gradient(::AutoReverseEnzyme, f, x::AbstractArray, extras::Nothing=nothing)
grad = gradient(Reverse, f, x)
return grad
end
Loading
Loading