Skip to content

Commit

Permalink
Basic rewrite of the package 2023 edition Part I: ADVI (#49)
Browse files Browse the repository at this point in the history
* refactor ADVI, change gradient operation interface

* remove unused file, remove unused dependency

* fix ADVI elbo computation more efficiently

* fix missing entropy regularization term

* add LogDensityProblem interface

* refactor use bijectors directly instead of transformed distributions

This is to avoid having to reconstruct transformed distributions all
the time. The direct use of bijectors also avoids going through lots
of abstraction layers that could break.

Instead, transformed distributions could be constructed only once when
returing the VI result.

* fix type restrictions

* remove unused file

* fix use of with_logabsdet_jacobian

* restructure project; move the main VI routine to its own file

* remove redundant import

* restructure project into more modular objective estimators

* migrate to AbstractDifferentiation

* add location scale pre-packaged variational family, add functors

* Revert "migrate to AbstractDifferentiation"

This reverts commit 2a4514e.

* fix use optimized MvNormal specialization, add logpdf for Loc.Scale.

* remove dead code

* fix location-scale logpdf

- Full Monte Carlo ELBO estimation now works. I checked.

* add sticking-the-landing (STL) estimator

* migrate to Optimisers.jl

* remove execution time measurement (replace later with somethin else)

* fix use multiple dispatch for deciding whether to stop entropy grad.

* add termination decision, callback arguments

* add Base.show to modules

* add interface calling `restructure`, rename rebuild -> restructure

* add estimator state interface, add control variate interface to ADVI

* fix `show(advi)` to show control variate

* fix simplify `show(advi.control_variate)`

* fix type piracy by wrapping location-scale bijected distribution

* remove old AdvancedVI custom optimizers

* fix Location Scale to not depend on Bijectors

* fix RNG namespace

* fix location scale logpdf bug

* add Accessors dependency

* add location scale, autodiff tests

* add Accessors import statement

* remove optimiser tests

* refactor slightly generalize the distribution tests for the future

* migrate to SimpleUnPack, migrate to ADTypes

* rename vi.jl to optimize.jl

* fix estimate_gradient to use adtypes

* add exact inference tests

* remove Turing dependency in tests

* remove unused projection

* remove redundant `ADVIEnergy` object (now baked into `ADVI`)

* add more tests, fix rng seed for tests

* add more tests, fix seed for tests

* fix non-determinism bug

* fix test hyperparameters so that tests pass, minor cleanups

* fix minor reorganization

* add missing files

* fix add missing file, rename adbackend argument

* fix errors

* rename test suite

* refactor renamed arguments for ADVI to be shorter

* fix compile error in advi test

* add initial doc

* remove unused epsilon argument in location scale

* add project file for documenter

* refactor STL gradient calculation to use multiple dispatch

* fix type bugs, relax test threshold for the exact inference tests

* refactor derivative utils to match NormalizingFlows.jl with extras

* add documentation, refactor optimize

* fix bug missing extension

* remove tracker from tests

* remove export for internal derivative utils

* fix test errors, old interface

* fix wrong derivative interface, add documentation

* update documentation

* add doc build CI

* remove convergence criterion for now

* remove outdated export

* update documentation

* update documentation

* update documentation

* fix type error in test

* remove default ADType argument

* update README

* update make getting started example actually run Julia

* fix remove Float32 tests for inference tests

* update version

* add documentation publishing url

* fix wrong uuid for ForwardDiff

* Update CI.yml

* refactor use `sum` and `mean` instead of abusing `mapreduce`

* remove tests for `FullMonteCarlo`

* add tests for the `optimize` interface

* fix turn off Zygote tests for now

* remove unused function

* refactor change bijector field name, simplify STL estimator

* update documentation

* update STL documentation

* update STL documentation

* update location scale documentation

* fix README

* fix math in README

* add gradient to arguments of callback!, remove `gradient_norm` info

* fix math in README.md

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>

* fix type constraint in `ZygoteExt`

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>

* fix import of `Random`

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>

* refactor `__init__()`

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>

* fix type constraint in definition of `value_and_gradient!`

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>

* refactor `ZygoteExt`; use `only` instead of `first`

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>

* refactor type constraint in `ReverseDiffExt`

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>

* refactor remove outdated debug mode macro

* fix remove outdated DEBUG mechanism

* fix LaTeX in README: `operatorname` is currently broken

* remove `SimpleUnPack` dependency

* fix LaTeX in docs and README

* add warning about forward-mode AD when using `LocationScale`

* fix documentation

* fix remove reamining use of `@unpack`

* Revert "remove `SimpleUnPack` dependency"

This reverts commit 29d7d27.

* Revert "fix remove reamining use of `@unpack`"

This reverts commit 8173744.

* fix documentation for `optimize`

* add specializations of `Optimise.destructure` for mean-field

* This fixes the poor performance of `ForwardDiff`
* This prevents the zero elements of the mean-field scale being extracted

* add test for `Optimisers.destructure` specializations

* add specialization of `rand` for meanfield resulting in faster AD

* add argument checks for `VIMeanFieldGaussian`, `VIFullRankGaussian`

* update documentation

* fix type instability, bug in argument check in `LocationScale`

* add missing import bug

* refactor test, fix type bug in tests for `LocationScale`

* add missing compat entries

* fix missing package import in test

* add additional tests for sampling `LocationScale`

* fix bug in batch in-place `rand!` for `LocationScale`

* fix bug in inference test initialization

* add missing file

* fix remove use of  for 1.6

* refactor adjust inference test hyperparameters to be more robust

* refactor `optimize` to return `obj_state`, add warm start kwargs

* refactor make tests more robust, reduce amount of tests

* fix remove a cholesky in test model

* fix compat bounds, remove unused package

* bump compat for ADTypes 0.2

* fix broken LaTeX in README

* remove redundant use of PDMats in docs

* fix use `Cholesky` signature supported in 1.6

* revert custom variational families and docs

* remove doc action for now

* revert README for now

* refactor remove redundant `rng` argument to `ADVI`, improve docs

* fix wrong whitespace in tests

* refactor `estimate_gradient` to `estimate_gradient!`, add docs

* refactor add default `init` impl, update docs

* merge (manually) commit ff32ac6

* fix test for new interface, change interface for `optimize`, `advi`

* fix integer subtype error in documentation of advi

Co-authored-by: Tor Erlend Fjelde <tor.github@gmail.com>

* fix remove redundant argument for `advi`

* remove manifest

* refactor remove imports and use fully qualified names

* update documentation for `AbstractVariationalObjective`

Co-authored-by: Tor Erlend Fjelde <tor.github@gmail.com>

* refactor use StableRNG instead of Random123

* refactor migrate to Test, re-enable x86 tests

* refactor remove inner constructor for `ADVI`

* fix swap `export`s and `include`s

Co-authored-by: Tor Erlend Fjelde <tor.github@gmail.com>

* fix doscs for `ADVI`

Co-authored-by: Tor Erlend Fjelde <tor.github@gmail.com>

* fix use `FillArrays` in the test problems

Co-authored-by: Tor Erlend Fjelde <tor.github@gmail.com>

* fix `optimize` docs

Co-authored-by: Tor Erlend Fjelde <tor.github@gmail.com>

* fix improve argument names and docs for `optimize`

* fix tests to match new interface of `optimize`

* refactor move utility functions to new file

* fix docs for `optimize`

Co-authored-by: Tor Erlend Fjelde <tor.github@gmail.com>

* refactor advi internal objective

Co-authored-by: Tor Erlend Fjelde <tor.github@gmail.com>

* refactor move `rng` to be an optional first argument

* fix docs for optimize

* add compat bounds to test dependencies

* update compat bound for `Optimisers`

* fix test compat

* fix remove `!` in callback

Co-authored-by: Tor Erlend Fjelde <tor.github@gmail.com>

* fix rng argument position in `advi`

* fix callback signature in `optimize`

* refactor reorganize test files and naming

* fix simplify description for `optimize`

Co-authored-by: Tor Erlend Fjelde <tor.github@gmail.com>

* fix remove redundant `Nothing` type signature for `maybe_init`

* fix remove "internal use" warning in documentation

* refactor change `estimate_gradient!` signature to be type stable

* add signature for computing `advi` over a fixed set of samples

* fix change test tolerance

* fix update documentation for `estimate_gradient!`

* refactor remove type constraint for variational parameters

* fix remove dead code

* add compat entry for stdlib

* add compat entry for stdlib in `test/`

* fix rng argument position in tests

* refactor change name of inference test

* fix documentation for `optimize`

* refactor rewrite the documentation for the global interfaces

* fix compat error

* fix documentation for `optimize` to be single line

* refactor remove begin end for one-liner

* refactor create unified interface for estimating objectives

* refactor unify interface for entropy estimator, fix advi docs

* fix STL estimator to use manually stopped gradients instead

* add inference test for a non-bijector model

* refactor add indirections to handle STL and bijectors in ADVI

* refactor split inference tests for advi+distributionsad

* refactor rename advi to repgradelbo and not use bijectors directly

* fix documentation for estimate_objective

* refactor add indirection in repgradelbo for interacting with `q`

* add TransformedDistribution support as extension

* Update src/objectives/elbo/repgradelbo.jl

Co-authored-by: Tor Erlend Fjelde <tor.github@gmail.com>

* fix docstring for entropy estimator

* fix `reparam_with_entropy` specialization for bijectors

* enable Zygote for non-bijector tests

---------

Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com>
Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
Co-authored-by: Tor Erlend Fjelde <tor.github@gmail.com>
  • Loading branch information
4 people authored Dec 8, 2023
1 parent b0c4be3 commit 576259a
Show file tree
Hide file tree
Showing 28 changed files with 1,191 additions and 560 deletions.
62 changes: 48 additions & 14 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,37 +1,71 @@
name = "AdvancedVI"
uuid = "b5ca4192-6429-45e5-a2d9-87aec30a685c"
version = "0.2.4"
version = "0.3.0"

[deps]
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"

[weakdeps]
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
AdvancedVIEnzymeExt = "Enzyme"
AdvancedVIForwardDiffExt = "ForwardDiff"
AdvancedVIReverseDiffExt = "ReverseDiff"
AdvancedVIZygoteExt = "Zygote"
AdvancedVIBijectorsExt = "Bijectors"

[compat]
Bijectors = "0.11, 0.12, 0.13"
Distributions = "0.21, 0.22, 0.23, 0.24, 0.25"
DistributionsAD = "0.2, 0.3, 0.4, 0.5, 0.6"
ADTypes = "0.1, 0.2"
Accessors = "0.1"
Bijectors = "0.13"
ChainRulesCore = "1.16"
DiffResults = "1"
Distributions = "0.25.87"
DocStringExtensions = "0.8, 0.9"
ForwardDiff = "0.10.3"
ProgressMeter = "1.0.0"
Requires = "0.5, 1.0"
Enzyme = "0.11.7"
FillArrays = "1.3"
ForwardDiff = "0.10.36"
Functors = "0.4"
LinearAlgebra = "1"
LogDensityProblems = "2"
Optimisers = "0.2.16, 0.3"
ProgressMeter = "1.6"
Random = "1"
Requires = "1.0"
ReverseDiff = "1.15.1"
SimpleUnPack = "1.1.0"
StatsBase = "0.32, 0.33, 0.34"
StatsFuns = "0.8, 0.9, 1"
Tracker = "0.2.3"
Zygote = "0.6.63"
julia = "1.6"

[extras]
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Pkg", "Test"]
45 changes: 45 additions & 0 deletions ext/AdvancedVIBijectorsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@

module AdvancedVIBijectorsExt

if isdefined(Base, :get_extension)
using AdvancedVI
using Bijectors
using Random
else
using ..AdvancedVI
using ..Bijectors
using ..Random
end

function AdvancedVI.reparam_with_entropy(
rng ::Random.AbstractRNG,
q ::Bijectors.TransformedDistribution,
q_stop ::Bijectors.TransformedDistribution,
n_samples::Int,
ent_est ::AdvancedVI.AbstractEntropyEstimator
)
transform = q.transform
q_base = q.dist
q_base_stop = q_stop.dist
base_samples = rand(rng, q_base, n_samples)
it = AdvancedVI.eachsample(base_samples)
sample_init = first(it)

samples_and_logjac = mapreduce(
AdvancedVI.catsamples_and_acc,
Iterators.drop(it, 1);
init=with_logabsdet_jacobian(transform, sample_init)
) do sample
with_logabsdet_jacobian(transform, sample)
end
samples = first(samples_and_logjac)
logjac = last(samples_and_logjac)

entropy_base = AdvancedVI.estimate_entropy_maybe_stl(
ent_est, base_samples, q_base, q_base_stop
)

entropy = entropy_base + logjac/n_samples
samples, entropy
end
end
26 changes: 26 additions & 0 deletions ext/AdvancedVIEnzymeExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@

module AdvancedVIEnzymeExt

if isdefined(Base, :get_extension)
using Enzyme
using AdvancedVI
using AdvancedVI: ADTypes, DiffResults
else
using ..Enzyme
using ..AdvancedVI
using ..AdvancedVI: ADTypes, DiffResults
end

# Enzyme doesn't support f::Bijectors (see https://github.com/EnzymeAD/Enzyme.jl/issues/916)
function AdvancedVI.value_and_gradient!(
ad::ADTypes.AutoEnzyme, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult
) where {T<:Real}
y = f(θ)
DiffResults.value!(out, y)
∇θ = DiffResults.gradient(out)
fill!(∇θ, zero(T))
Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(θ, ∇θ))
return out
end

end
29 changes: 29 additions & 0 deletions ext/AdvancedVIForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@

module AdvancedVIForwardDiffExt

if isdefined(Base, :get_extension)
using ForwardDiff
using AdvancedVI
using AdvancedVI: ADTypes, DiffResults
else
using ..ForwardDiff
using ..AdvancedVI
using ..AdvancedVI: ADTypes, DiffResults
end

getchunksize(::ADTypes.AutoForwardDiff{chunksize}) where {chunksize} = chunksize

function AdvancedVI.value_and_gradient!(
ad::ADTypes.AutoForwardDiff, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult
) where {T<:Real}
chunk_size = getchunksize(ad)
config = if isnothing(chunk_size)
ForwardDiff.GradientConfig(f, θ)
else
ForwardDiff.GradientConfig(f, θ, ForwardDiff.Chunk(length(θ), chunk_size))
end
ForwardDiff.gradient!(out, f, θ, config)
return out
end

end
23 changes: 23 additions & 0 deletions ext/AdvancedVIReverseDiffExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@

module AdvancedVIReverseDiffExt

if isdefined(Base, :get_extension)
using AdvancedVI
using AdvancedVI: ADTypes, DiffResults
using ReverseDiff
else
using ..AdvancedVI
using ..AdvancedVI: ADTypes, DiffResults
using ..ReverseDiff
end

# ReverseDiff without compiled tape
function AdvancedVI.value_and_gradient!(
ad::ADTypes.AutoReverseDiff, f, θ::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult
)
tp = ReverseDiff.GradientTape(f, θ)
ReverseDiff.gradient!(out, tp, θ)
return out
end

end
24 changes: 24 additions & 0 deletions ext/AdvancedVIZygoteExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@

module AdvancedVIZygoteExt

if isdefined(Base, :get_extension)
using AdvancedVI
using AdvancedVI: ADTypes, DiffResults
using Zygote
else
using ..AdvancedVI
using ..AdvancedVI: ADTypes, DiffResults
using ..Zygote
end

function AdvancedVI.value_and_gradient!(
ad::ADTypes.AutoZygote, f, θ::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult
)
y, back = Zygote.pullback(f, θ)
∇θ = back(one(y))
DiffResults.value!(out, y)
DiffResults.gradient!(out, only(∇θ))
return out
end

end
Loading

0 comments on commit 576259a

Please sign in to comment.