Support mutating functions f!(y, x) #41

Merged · 3 commits · Mar 14, 2024
12 changes: 6 additions & 6 deletions README.md
@@ -9,16 +9,17 @@ An interface to various automatic differentiation backends in Julia.

## Goal

This package provides a backend-agnostic syntax to differentiate functions `f(x) = y`, where `x` and `y` are either real numbers or abstract arrays.
This package provides a backend-agnostic syntax to differentiate functions of the following types:

It supports in-place versions of every operator, and ensures type stability whenever possible.
- Allocating: `f(x) = y` where `x` and `y` can be real numbers or abstract arrays
- Mutating: `f!(y, x)` where `y` is an abstract array and `x` can be a real number or an abstract array
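
For illustration, the two conventions look like this (a toy sketch, not code taken from this package):

```julia
# Allocating: the function returns a fresh output
f(x) = x .^ 2

# Mutating: the function writes its result into a preallocated array `y`
f!(y, x) = (y .= x .^ 2; nothing)
```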

## Compatibility

We support some of the backends defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl):

| Backend | Type |
|:--------------------------------------------------------------------------------|:-----------------------------------------------------------|
| :------------------------------------------------------------------------------ | :--------------------------------------------------------- |
| [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl) | `AutoChainRules(ruleconfig)` |
| [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl) | `AutoDiffractor()` |
| [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) | `AutoEnzyme(Val(:forward))` or `AutoEnzyme(Val(:reverse))` |
@@ -64,13 +65,12 @@ julia> grad

## Related packages

- [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl) is the original inspiration for DifferentiationInterface.jl. We aim to be less generic (one input, one output, first order only) but more efficient (type stability, memory reuse).
- [AutoDiffOperators.jl](https://github.com/oschulz/AutoDiffOperators.jl) is an attempt to bridge ADTypes.jl with AbstractDifferentiation.jl. We provide similar functionality (except for the matrix-like behavior) but cover more backends.
- [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl) is the original inspiration for DifferentiationInterface.jl.
- [AutoDiffOperators.jl](https://github.com/oschulz/AutoDiffOperators.jl) is an attempt to bridge ADTypes.jl with AbstractDifferentiation.jl.

## Roadmap

Goals for future releases:

- implement backend-specific cache objects
- support in-place functions `f!(y, x)`
- define user-facing functions to test and benchmark backends against each other
2 changes: 2 additions & 0 deletions docs/make.jl
@@ -1,5 +1,6 @@
using Base: get_extension
using DifferentiationInterface
using DifferentiationInterface.DifferentiationTest
import DifferentiationInterface as DI
using Documenter
using DocumenterMermaid
@@ -49,6 +50,7 @@ makedocs(;
modules=[
ADTypes,
DifferentiationInterface,
DifferentiationInterface.DifferentiationTest,
ChainRulesCoreExt,
DiffractorExt,
EnzymeExt,
9 changes: 8 additions & 1 deletion docs/src/api.md
@@ -64,6 +64,13 @@ These are not part of the public API.

```@autodocs
Modules = [DifferentiationInterface]
Pages = ["backends.jl", "mode.jl", "utils.jl", "DifferentiationTest.jl"]
Public = false
```

## Submodules

These are not part of the public API.

```@autodocs
Modules = [DifferentiationTest]
```
13 changes: 13 additions & 0 deletions docs/src/backends.md
@@ -24,6 +24,19 @@ AutoReverseDiff
AutoZygote
```

## Accepted functions

| Backend | `f(x) = y` | `f!(y, x)` |
| -------------------------- | ---------- | ---------- |
| `AutoChainRules` | yes | no |
| `AutoDiffractor` | yes | no |
| `AutoEnzyme` | yes | soon |
| `AutoForwardDiff` | yes | yes |
| `AutoFiniteDiff` | yes | soon |
| `AutoPolyesterForwardDiff` | yes | soon |
| `AutoReverseDiff` | yes | soon |
| `AutoZygote` | yes | no |
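
For instance, a mutating function currently requires a backend from the `yes` column of `f!(y, x)`; the constructors are the ones listed in the compatibility table (a minimal illustration, not a complete script):

```julia
using ADTypes

backend_both = AutoForwardDiff()      # accepts both f(x) = y and f!(y, x)
backend_alloc_only = AutoZygote()     # accepts only allocating functions f(x) = y
```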

## Package extensions

```@meta
47 changes: 44 additions & 3 deletions docs/src/developer.md
@@ -21,8 +21,7 @@ Advanced users are welcome to code more backends and submit pull requests!

Edge labels correspond to the number of function calls made when applying operators to a function $f: \mathbb{R}^n \rightarrow \mathbb{R}^m$.


### Forward mode
### Forward mode, allocating functions

```mermaid
flowchart LR
@@ -69,7 +68,28 @@ flowchart LR
end
```

### Reverse mode
### Forward mode, mutating functions

```mermaid
flowchart LR
subgraph Multiderivative
value_and_multiderivative!
end

value_and_multiderivative! --> value_and_pushforward!

subgraph Jacobian
value_and_jacobian!
end

value_and_jacobian! --> |n|value_and_pushforward!

subgraph Pushforward
value_and_pushforward!
end
```
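
The `|n|` edge means the Jacobian fallback performs one pushforward per input dimension. Below is a self-contained sketch of that structure, with a plain finite-difference step standing in for the backend's pushforward (names and the finite differences are purely illustrative, not the package's actual code):

```julia
# Illustrative only: build the Jacobian of a mutating f!(y, x) column by column,
# one (finite-difference) pushforward per input dimension.
function sketch_jacobian_forward!(y, jac, f!, x::AbstractVector; h=1e-6)
    f!(y, x)                               # primal value
    y_shift = similar(y)
    for j in eachindex(x)
        x_shift = copy(x)
        x_shift[j] += h                    # perturb along the j-th basis vector
        f!(y_shift, x_shift)
        jac[:, j] .= (y_shift .- y) ./ h   # column j ≈ ∂y/∂xⱼ
    end
    return y, jac
end

f!(y, x) = (y .= x .^ 2; nothing)
y, jac = sketch_jacobian_forward!(zeros(3), zeros(3, 3), f!, [1.0, 2.0, 3.0])
```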

### Reverse mode, allocating functions

```mermaid
flowchart LR
@@ -115,3 +135,24 @@
pullback --> value_and_pullback
end
```

### Reverse mode, mutating functions

```mermaid
flowchart LR
subgraph Multiderivative
value_and_multiderivative!
end

value_and_multiderivative! --> |m|value_and_pullback!

subgraph Jacobian
value_and_jacobian!
end

value_and_jacobian! --> |m|value_and_pullback!

subgraph Pullback
value_and_pullback!
end
```
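
Here the `|m|` edge means the Jacobian is assembled row by row, one pullback (vector-Jacobian product) per output dimension. In the sketch below the pullback of one specific function is written by hand, standing in for a reverse-mode backend (illustrative only, not the package's code):

```julia
f!(y, x) = (y[1] = sum(abs2, x); y[2] = prod(x); nothing)

# Hand-written vector-Jacobian product dyᵀ * J for this particular f! (illustrative).
vjp(x, dy) = dy[1] .* (2 .* x) .+ dy[2] .* (prod(x) ./ x)

function sketch_jacobian_reverse!(y, jac, f!, vjp, x::AbstractVector)
    f!(y, x)                          # primal value
    for i in eachindex(y)
        dy = zero(y)
        dy[i] = one(eltype(y))        # i-th output basis vector
        jac[i, :] .= vjp(x, dy)       # row i = eᵢᵀ J
    end
    return y, jac
end

y, jac = sketch_jacobian_reverse!(zeros(2), zeros(2, 3), f!, vjp, [1.0, 2.0, 3.0])
```
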
23 changes: 21 additions & 2 deletions docs/src/getting_started.md
@@ -12,11 +12,11 @@ We choose the following terminology for the ones we provide:

Most backends have custom implementations for all of these, which we reuse whenever possible.

### Variants
## Variants

Whenever it makes sense, four variants of the same operator are defined:

| **Operator** | **non-mutating** | **mutating** | **non-mutating with primal** | **mutating with primal** |
| **Operator** | **allocating** | **mutating** | **allocating with primal** | **mutating with primal** |
| :---------------- | :------------------------ | :------------------------- | :---------------------------------- | :----------------------------------- |
| Derivative | [`derivative`](@ref) | N/A | [`value_and_derivative`](@ref) | N/A |
| Multiderivative | [`multiderivative`](@ref) | [`multiderivative!`](@ref) | [`value_and_multiderivative`](@ref) | [`value_and_multiderivative!`](@ref) |
@@ -49,3 +49,22 @@ This is especially worth it if you plan to call `operator` several times in similar settings.

By default, all the preparation functions return `nothing`.
We do not make any guarantees on their implementation for each backend, or on the performance gains that can be expected.
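
As an illustration of the kind of cache such a preparation step can return, ForwardDiff (one of the supported backends, whose config types this PR's extension imports) lets you allocate its dual-number work buffers once and reuse them across calls. Whether `prepare_operator` wraps exactly this object is backend-dependent and, as stated above, not guaranteed:

```julia
using ForwardDiff: JacobianConfig, jacobian!

f(x) = x .^ 2
x = rand(5)
jac = zeros(5, 5)

cfg = JacobianConfig(f, x)       # allocate ForwardDiff's work buffers once
for _ in 1:100
    jacobian!(jac, f, x, cfg)    # reuse them on every call
end
```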

## Mutating functions

In addition to allocating functions `f(x) = y`, we also support mutating functions `f!(y, x)` whenever the output is an array.
Since they operate in-place and the primal is computed every time, only four operators are defined:

| **Operator** | **mutating with primal** |
| :---------------- | :----------------------------------- |
| Multiderivative | [`value_and_multiderivative!`](@ref) |
| Jacobian | [`value_and_jacobian!`](@ref) |
| Pushforward (JVP) | [`value_and_pushforward!`](@ref) |
| Pullback (VJP) | [`value_and_pullback!`](@ref) |

Furthermore, the preparation function takes an additional argument: `prepare_operator(backend, f!, x, y)`.
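
A hypothetical end-to-end call might look as follows. The argument order of `value_and_jacobian!` for mutating functions is an assumption here (it mirrors the in-place variants elsewhere in this PR), and `prepare_operator` stands for the concrete preparation function of the chosen operator, so check the API reference before relying on this sketch:

```julia
using DifferentiationInterface, ADTypes, ForwardDiff

f!(y, x) = (y .= x .^ 2; nothing)   # mutating function: writes into y

backend = AutoForwardDiff()
x = [1.0, 2.0, 3.0]
y = zeros(3)
jac = zeros(3, 3)

# Assumed signatures, for illustration only:
extras = prepare_operator(backend, f!, x, y)            # note the extra `y` argument
value_and_jacobian!(y, jac, backend, f!, x, extras)     # fills both `y` and `jac`
```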

## Multiple inputs/outputs

Restricting the API to one input and one output has many coding advantages, but it is not very flexible.
If you need more than that, use [ComponentArrays.jl](https://github.com/jonniedie/ComponentArrays.jl) to wrap several objects inside a single `ComponentVector`.
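
For instance, two conceptually separate inputs can be bundled into a single `ComponentVector` and differentiated as one object. This is a sketch; the `gradient` call assumes the `(backend, f, x)` argument order used throughout this PR:

```julia
using ComponentArrays, DifferentiationInterface, ADTypes, ForwardDiff

# Two logical inputs packed into one differentiable vector
x = ComponentVector(a = 1.0, b = [2.0, 3.0])
f(x) = x.a * sum(abs2, x.b)

grad = gradient(AutoForwardDiff(), f, x)   # assumed call; gradient is one of the operators
grad.a, grad.b                             # component labels are typically preserved
```
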
@@ -26,7 +26,12 @@ function DI.value_and_pushforward(
end

function DI.value_and_pushforward!(
dy, backend::AutoForwardChainRules, f, x, dx, extras=nothing
dy::Union{Number,AbstractArray},
backend::AutoForwardChainRules,
f,
x,
dx,
extras=nothing,
)
y, new_dy = DI.value_and_pushforward(backend, f, x, dx, extras)
return y, update!(dy, new_dy)
@@ -42,7 +47,12 @@ function DI.value_and_pullback(
end

function DI.value_and_pullback!(
dx, backend::AutoReverseChainRules, f, x, dy, extras=nothing
dx::Union{Number,AbstractArray},
backend::AutoReverseChainRules,
f,
x,
dy,
extras=nothing,
)
y, new_dx = DI.value_and_pullback(backend, f, x, dy, extras)
return y, update!(dx, new_dx)
@@ -16,7 +16,9 @@ function DI.value_and_pushforward(::AutoDiffractor, f, x, dx, extras::Nothing=nothing)
return y, dy
end

function DI.value_and_pushforward!(dy, ::AutoDiffractor, f, x, dx, extras::Nothing=nothing)
function DI.value_and_pushforward!(
dy::Union{Number,AbstractArray}, ::AutoDiffractor, f, x, dx, extras::Nothing=nothing
)
vpff = AD.value_and_pushforward_function(DiffractorForwardBackend(), f, x)
y, new_dy = vpff((dx,))
return y, update!(dy, new_dy)
8 changes: 4 additions & 4 deletions ext/DifferentiationInterfaceEnzymeExt/forward.jl
@@ -4,15 +4,15 @@ DI.autodiff_mode(::AutoForwardEnzyme) = DI.ForwardMode()
## Primitives

function DI.value_and_pushforward!(
_dy::Y, ::AutoForwardEnzyme, f, x::X, dx, extras::Nothing=nothing
) where {X,Y<:Real}
_dy::Real, ::AutoForwardEnzyme, f, x, dx, extras::Nothing=nothing
)
y, new_dy = autodiff(Forward, f, Duplicated, Duplicated(x, dx))
return y, new_dy
end

function DI.value_and_pushforward!(
dy::Y, ::AutoForwardEnzyme, f, x::X, dx, extras::Nothing=nothing
) where {X,Y<:AbstractArray}
dy::AbstractArray, ::AutoForwardEnzyme, f, x, dx, extras::Nothing=nothing
)
y, new_dy = autodiff(Forward, f, Duplicated, Duplicated(x, dx))
dy .= new_dy
return y, dy
31 changes: 23 additions & 8 deletions ext/DifferentiationInterfaceEnzymeExt/reverse.jl
@@ -15,34 +15,49 @@ end
## Primitives

function DI.value_and_pullback!(
_dx, ::AutoReverseEnzyme, f, x::X, dy::Y, extras::Nothing=nothing
) where {X<:Number,Y<:Union{Real,Nothing}}
_dx::Number, ::AutoReverseEnzyme, f, x::Number, dy, extras::Nothing=nothing
)
der, y = autodiff(ReverseWithPrimal, f, Active, Active(x))
new_dx = dy * only(der)
return y, new_dx
end

function DI.value_and_pullback!(
dx::X, ::AutoReverseEnzyme, f, x::X, dy::Y, extras::Nothing=nothing
) where {X<:AbstractArray,Y<:Union{Real,Nothing}}
dx::AbstractArray,
::AutoReverseEnzyme,
f,
x::AbstractArray,
dy::Number,
extras::Nothing=nothing,
)
dx .= zero(eltype(dx))
_, y = autodiff(ReverseWithPrimal, f, Active, Duplicated(x, dx))
dx .*= dy
return y, dx
end

function DI.value_and_pullback!(
_dx, ::AutoReverseEnzyme, f, x::X, dy::Y, extras::Nothing=nothing
) where {X<:Number,Y<:AbstractArray}
_dx::Number,
::AutoReverseEnzyme,
f,
x::Number,
dy::AbstractArray,
extras::Nothing=nothing,
)
y = f(x)
mf! = MakeFunctionMutating(f)
_, new_dx = only(autodiff(Reverse, mf!, Const, Duplicated(y, copy(dy)), Active(x)))
return y, new_dx
end

function DI.value_and_pullback!(
dx, ::AutoReverseEnzyme, f, x::X, dy::Y, extras::Nothing=nothing
) where {X<:AbstractArray,Y<:AbstractArray}
dx::AbstractArray,
::AutoReverseEnzyme,
f,
x::AbstractArray,
dy::AbstractArray,
extras::Nothing=nothing,
)
y = f(x)
dx_like_x = zero(x)
mf! = MakeFunctionMutating(f)
@@ -17,17 +17,17 @@ const FUNCTION_NOT_INPLACE = Val{false}
## Primitives

function DI.value_and_pushforward!(
dy::Y, ::AutoFiniteDiff{fdtype}, f, x, dx, extras::Nothing=nothing
) where {Y<:Number,fdtype}
_dy::Number, ::AutoFiniteDiff{fdtype}, f, x, dx, extras::Nothing=nothing
) where {fdtype}
y = f(x)
step(t::Number)::Number = f(x .+ t .* dx)
new_dy = finite_difference_derivative(step, zero(eltype(dx)), fdtype, eltype(y), y)
return y, new_dy
end

function DI.value_and_pushforward!(
dy::Y, ::AutoFiniteDiff{fdtype}, f, x, dx, extras::Nothing=nothing
) where {Y<:AbstractArray,fdtype}
dy::AbstractArray, ::AutoFiniteDiff{fdtype}, f, x, dx, extras::Nothing=nothing
) where {fdtype}
y = f(x)
step(t::Number)::AbstractArray = f(x .+ t .* dx)
finite_difference_gradient!(
@@ -0,0 +1,34 @@
module DifferentiationInterfaceForwardDiffExt

using ADTypes: AutoForwardDiff
import DifferentiationInterface as DI
using DiffResults: DiffResults
using DocStringExtensions
using ForwardDiff:
Chunk,
Dual,
DerivativeConfig,
GradientConfig,
JacobianConfig,
Tag,
derivative,
derivative!,
extract_derivative,
extract_derivative!,
gradient,
gradient!,
jacobian,
jacobian!,
value
using LinearAlgebra: mul!

choose_chunk(::AutoForwardDiff{nothing}, x) = Chunk(x)
choose_chunk(::AutoForwardDiff{C}, x) where {C} = Chunk{C}()

tag_type(::F, ::V) where {F,V<:Number} = Tag{F,V}
tag_type(::F, ::AbstractArray{V}) where {F,V<:Number} = Tag{F,V}

include("non_mutating.jl")
include("mutating.jl")

end # module