From b49cf3e8cf2162706824735f0662559d6f838d55 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 14 Mar 2023 19:13:41 +0000 Subject: [PATCH 001/144] refactor ADVI, change gradient operation interface --- Project.toml | 1 + src/AdvancedVI.jl | 181 ++++++++++++++--------------------------- src/advi.jl | 47 ----------- src/estimators/advi.jl | 29 +++++++ src/utils.jl | 15 ++++ 5 files changed, 107 insertions(+), 166 deletions(-) create mode 100644 src/estimators/advi.jl diff --git a/Project.toml b/Project.toml index 28adc66a..71a2cbdc 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,7 @@ DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +ParameterHandling = "2412ca09-6db7-441c-8e3a-88d5709968c5" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index e203a13c..d42683d0 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -33,20 +33,12 @@ function __init__() export ZygoteAD function AdvancedVI.grad!( - vo, - alg::VariationalInference{<:AdvancedVI.ZygoteAD}, - q, - model, - θ::AbstractVector{<:Real}, + f::Function, + ::Type{<:ZygoteAD}, + λ::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult, - args... ) - f(θ) = if (q isa Distribution) - - vo(alg, update(q, θ), model, args...) - else - - vo(alg, q(θ), model, args...) 
- end - y, back = Zygote.pullback(f, θ) + y, back = Zygote.pullback(f, λ) dy = first(back(1.0)) DiffResults.value!(out, y) DiffResults.gradient!(out, dy) @@ -58,21 +50,13 @@ function __init__() export ReverseDiffAD function AdvancedVI.grad!( - vo, - alg::VariationalInference{<:AdvancedVI.ReverseDiffAD{false}}, - q, - model, - θ::AbstractVector{<:Real}, + f::Function, + ::Type{<:ReverseDiffAD}, + λ::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult, - args... ) - f(θ) = if (q isa Distribution) - - vo(alg, update(q, θ), model, args...) - else - - vo(alg, q(θ), model, args...) - end - tp = AdvancedVI.tape(f, θ) - ReverseDiff.gradient!(out, tp, θ) + tp = AdvancedVI.tape(f, λ) + ReverseDiff.gradient!(out, tp, λ) return out end end @@ -81,26 +65,18 @@ function __init__() export EnzymeAD function AdvancedVI.grad!( - vo, - alg::VariationalInference{<:AdvancedVI.EnzymeAD}, - q, - model, - θ::AbstractVector{<:Real}, + f::Function, + ::Type{<:EnzymeAD}, + λ::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult, - args... ) - f(θ) = if (q isa Distribution) - - vo(alg, update(q, θ), model, args...) - else - - vo(alg, q(θ), model, args...) 
- end # Use `Enzyme.ReverseWithPrimal` once it is released: # https://github.com/EnzymeAD/Enzyme.jl/pull/598 - y = f(θ) + y = f(λ) DiffResults.value!(out, y) dy = DiffResults.gradient(out) fill!(dy, 0) - Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(θ, dy)) + Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(λ, dy)) return out end end @@ -109,16 +85,8 @@ end export vi, ADVI, - ELBO, - elbo, TruncatedADAGrad, - DecayedADAGrad, - VariationalInference - -abstract type VariationalInference{AD} end - -getchunksize(::Type{<:VariationalInference{AD}}) where AD = getchunksize(AD) -getADtype(::VariationalInference{AD}) where AD = AD + DecayedADAGrad abstract type VariationalObjective end @@ -126,13 +94,11 @@ const VariationalPosterior = Distribution{Multivariate, Continuous} """ - grad!(vo, alg::VariationalInference, q, model::Model, θ, out, args...) + grad!(f, λ, out) -Computes the gradients used in `optimize!`. Default implementation is provided for +Computes the gradients of the objective f. Default implementation is provided for `VariationalInference{AD}` where `AD` is either `ForwardDiffAD` or `TrackerAD`. This implicitly also gives a default implementation of `optimize!`. - -Variance reduction techniques, e.g. control variates, should be implemented in this function. """ function grad! end @@ -157,51 +123,36 @@ function update end # default implementations function grad!( - vo, - alg::VariationalInference{<:ForwardDiffAD}, - q, - model, - θ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult, - args... + f::Function, + adtype::Type{<:ForwardDiffAD}, + λ::AbstractVector{<:Real}, + out::DiffResults.MutableDiffResult ) - f(θ_) = if (q isa Distribution) - - vo(alg, update(q, θ_), model, args...) - else - - vo(alg, q(θ_), model, args...) - end - # Set chunk size and do ForwardMode. 
- chunk_size = getchunksize(typeof(alg)) + chunk_size = getchunksize(adtype) config = if chunk_size == 0 - ForwardDiff.GradientConfig(f, θ) + ForwardDiff.GradientConfig(f, λ) else - ForwardDiff.GradientConfig(f, θ, ForwardDiff.Chunk(length(θ), chunk_size)) + ForwardDiff.GradientConfig(f, λ, ForwardDiff.Chunk(length(λ), chunk_size)) end - ForwardDiff.gradient!(out, f, θ, config) + ForwardDiff.gradient!(out, f, λ, config) end function grad!( - vo, - alg::VariationalInference{<:TrackerAD}, - q, - model, - θ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult, - args... + f::Function, + ::Type{<:TrackerAD}, + λ::AbstractVector{<:Real}, + out::DiffResults.MutableDiffResult ) - θ_tracked = Tracker.param(θ) - y = if (q isa Distribution) - - vo(alg, update(q, θ_tracked), model, args...) - else - - vo(alg, q(θ_tracked), model, args...) - end + λ_tracked = Tracker.param(λ) + y = f(λ_tracked) Tracker.back!(y, 1.0) DiffResults.value!(out, Tracker.data(y)) - DiffResults.gradient!(out, Tracker.grad(θ_tracked)) + DiffResults.gradient!(out, Tracker.grad(λ_tracked)) end +abstract type AbstractGradientEstimator end """ optimize!(vo, alg::VariationalInference{AD}, q::VariationalPosterior, model::Model, θ; optimizer = TruncatedADAGrad()) @@ -210,61 +161,53 @@ Iteratively updates parameters by calling `grad!` and using the given `optimizer the steps. """ function optimize!( - vo, - alg::VariationalInference, - q, - model, - θ::AbstractVector{<:Real}; - optimizer = TruncatedADAGrad() + grad_estimator::AbstractGradientEstimator, + rebuild::Function, + ℓπ::Function, + n_max_iter::Int, + λ::AbstractVector{<:Real}; + optimizer = TruncatedADAGrad(), + rng = Random.GLOBAL_RNG ) - # TODO: should we always assume `samples_per_step` and `max_iters` for all algos? 
- alg_name = alg_str(alg) - samples_per_step = alg.samples_per_step - max_iters = alg.max_iters - - num_params = length(θ) + obj_name = objective(grad_estimator) # TODO: really need a better way to warn the user about potentially # not using the correct accumulator - if (optimizer isa TruncatedADAGrad) && (θ ∉ keys(optimizer.acc)) + if (optimizer isa TruncatedADAGrad) && (λ ∉ keys(optimizer.acc)) # this message should only occurr once in the optimization process - @info "[$alg_name] Should only be seen once: optimizer created for θ" objectid(θ) + @info "[$obj_name] Should only be seen once: optimizer created for θ" objectid(λ) end - diff_result = DiffResults.GradientResult(θ) + grad_buf = DiffResults.GradientResult(λ) i = 0 - prog = if PROGRESS[] - ProgressMeter.Progress(max_iters, 1, "[$alg_name] Optimizing...", 0) - else - 0 - end + prog = ProgressMeter.Progress( + n_max_iter; desc="[$obj_name] Optimizing...", barlen=0, enabled=PROGRESS[]) # add criterion? A running mean maybe? - time_elapsed = @elapsed while (i < max_iters) # & converged - grad!(vo, alg, q, model, θ, diff_result, samples_per_step) - - # apply update rule - Δ = DiffResults.gradient(diff_result) - Δ = apply!(optimizer, θ, Δ) - @. θ = θ - Δ + time_elapsed = @elapsed begin + for i = 1:n_max_iter + stats = estimate_gradient!(rng, grad_estimator, λ, rebuild, ℓπ, grad_buf) + + # apply update rule + Δλ = DiffResults.gradient(grad_buf) + Δλ = apply!(optimizer, λ, Δλ) + @. λ = λ - Δλ + + stat′ = (Δλ=norm(Δλ),) + stats = merge(stats, stat′) - AdvancedVI.DEBUG && @debug "Step $i" Δ DiffResults.value(diff_result) - PROGRESS[] && (ProgressMeter.next!(prog)) - - i += 1 + AdvancedVI.DEBUG && @debug "Step $i" stats... 
+ pm_next!(prog, stats) + end end - - return θ + return λ end # objectives -include("objectives.jl") +include("estimators/advi.jl") # optimisers include("optimisers.jl") -# VI algorithms -include("advi.jl") - end # module diff --git a/src/advi.jl b/src/advi.jl index 7f9e7346..be9823db 100644 --- a/src/advi.jl +++ b/src/advi.jl @@ -50,50 +50,3 @@ function optimize(elbo::ELBO, alg::ADVI, q, model, θ_init; optimizer = Truncate return θ end -# WITHOUT updating parameters inside ELBO -function (elbo::ELBO)( - rng::Random.AbstractRNG, - alg::ADVI, - q::VariationalPosterior, - logπ::Function, - num_samples -) - # 𝔼_q(z)[log p(xᵢ, z)] - # = ∫ log p(xᵢ, z) q(z) dz - # = ∫ log p(xᵢ, f(ϕ)) q(f(ϕ)) |det J_f(ϕ)| dϕ (since change of variables) - # = ∫ log p(xᵢ, f(ϕ)) q̃(ϕ) dϕ (since q(f(ϕ)) |det J_f(ϕ)| = q̃(ϕ)) - # = 𝔼_q̃(ϕ)[log p(xᵢ, z)] - - # 𝔼_q(z)[log q(z)] - # = ∫ q(f(ϕ)) log (q(f(ϕ))) |det J_f(ϕ)| dϕ (since q(f(ϕ)) |det J_f(ϕ)| = q̃(ϕ)) - # = 𝔼_q̃(ϕ) [log q(f(ϕ))] - # = 𝔼_q̃(ϕ) [log q̃(ϕ) - log |det J_f(ϕ)|] - # = 𝔼_q̃(ϕ) [log q̃(ϕ)] - 𝔼_q̃(ϕ) [log |det J_f(ϕ)|] - # = - ℍ(q̃(ϕ)) - 𝔼_q̃(ϕ) [log |det J_f(ϕ)|] - - # Finally, the ELBO is given by - # ELBO = 𝔼_q(z)[log p(xᵢ, z)] - 𝔼_q(z)[log q(z)] - # = 𝔼_q̃(ϕ)[log p(xᵢ, z)] + 𝔼_q̃(ϕ) [log |det J_f(ϕ)|] + ℍ(q̃(ϕ)) - - # If f: supp(p(z | x)) → ℝ then - # ELBO = 𝔼[log p(x, z) - log q(z)] - # = 𝔼[log p(x, f⁻¹(z̃)) + logabsdet(J(f⁻¹(z̃)))] + ℍ(q̃(z̃)) - # = 𝔼[log p(x, z) - logabsdetjac(J(f(z)))] + ℍ(q̃(z̃)) - - # But our `rand_and_logjac(q)` is using f⁻¹: ℝ → supp(p(z | x)) going forward → `+ logjac` - z, logjac = rand_and_logjac(rng, q) - res = (logπ(z) + logjac) / num_samples - - if q isa TransformedDistribution - res += entropy(q.dist) - else - res += entropy(q) - end - - for i = 2:num_samples - z, logjac = rand_and_logjac(rng, q) - res += (logπ(z) + logjac) / num_samples - end - - return res -end diff --git a/src/estimators/advi.jl b/src/estimators/advi.jl new file mode 100644 index 00000000..c5a83957 --- /dev/null +++ 
b/src/estimators/advi.jl @@ -0,0 +1,29 @@ + +struct ADVI <: AbstractGradientEstimator + n_samples::Int +end + +objective(::ADVI) = "ELBO" + +function estimate_gradient!( + rng::Random.AbstractRNG, + estimator::ADVI, + λ::Vector{<:Real}, + rebuild::Function, + logπ::Function, + out::DiffResults.MutableDiffResult) + + n_samples = estimator.n_samples + + grad!(ADBackend(), λ, out) do λ′ + q = rebuild(λ′) + zs, ∑logjac = rand_and_logjac(rng, q, estimator.n_samples) + + elbo = mapreduce(+, eachcol(zs)) do zᵢ + (logπ(zᵢ) + ∑logjac) + end / n_samples + -elbo + end + nelbo = DiffResults.value(out) + (elbo=-nelbo,) +end diff --git a/src/utils.jl b/src/utils.jl index bb4c1f18..87cc0856 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -13,3 +13,18 @@ function rand_and_logjac(rng::Random.AbstractRNG, dist::Bijectors.TransformedDis y, logjac = Bijectors.with_logabsdet_jacobian(dist.transform, x) return y, logjac end + +function rand_and_logjac(rng::Random.AbstractRNG, dist::Distribution, n_samples::Int) + x = rand(rng, dist, n_samples) + return x, zero(eltype(x)) +end + +function rand_and_logjac(rng::Random.AbstractRNG, dist::Bijectors.TransformedDistribution, n_samples::Int) + x = rand(rng, dist.dist, n_samples) + y, logjac = Bijectors.with_logabsdet_jacobian(dist.transform, x) + return y, logjac +end + +function pm_next!(pm, stats::NamedTuple) + ProgressMeter.next!(pm; showvalues=[tuple(s...) 
for s in pairs(stats)]) +end From 88e0b79758c2f207b9d3c7120b469af837049fec Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 14 Mar 2023 19:56:47 +0000 Subject: [PATCH 002/144] remove unused file, remove unused dependency --- Project.toml | 1 - src/objectives.jl | 7 ------- 2 files changed, 8 deletions(-) delete mode 100644 src/objectives.jl diff --git a/Project.toml b/Project.toml index 71a2cbdc..28adc66a 100644 --- a/Project.toml +++ b/Project.toml @@ -9,7 +9,6 @@ DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -ParameterHandling = "2412ca09-6db7-441c-8e3a-88d5709968c5" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" diff --git a/src/objectives.jl b/src/objectives.jl deleted file mode 100644 index 5a6b61b0..00000000 --- a/src/objectives.jl +++ /dev/null @@ -1,7 +0,0 @@ -struct ELBO <: VariationalObjective end - -function (elbo::ELBO)(alg, q, logπ, num_samples; kwargs...) - return elbo(Random.default_rng(), alg, q, logπ, num_samples; kwargs...) 
-end - -const elbo = ELBO() From c2fb3f8d08c15b16fa2e84a359b0d9bda3bf45b2 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Wed, 15 Mar 2023 18:53:50 +0000 Subject: [PATCH 003/144] fix ADVI elbo computation more efficiently --- src/estimators/advi.jl | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/estimators/advi.jl b/src/estimators/advi.jl index c5a83957..44f65909 100644 --- a/src/estimators/advi.jl +++ b/src/estimators/advi.jl @@ -17,13 +17,17 @@ function estimate_gradient!( grad!(ADBackend(), λ, out) do λ′ q = rebuild(λ′) - zs, ∑logjac = rand_and_logjac(rng, q, estimator.n_samples) - - elbo = mapreduce(+, eachcol(zs)) do zᵢ - (logπ(zᵢ) + ∑logjac) - end / n_samples + zs, ∑logdetjac = rand_and_logjac(rng, q, estimator.n_samples) + + 𝔼logπ = mapreduce(+, eachcol(zs)) do zᵢ + logπ(zᵢ) / n_samples + end + 𝔼logdetjac = ∑logdetjac/n_samples + + elbo = 𝔼logπ + 𝔼logdetjac -elbo end nelbo = DiffResults.value(out) (elbo=-nelbo,) end + From 83161fdf7fd18d9f686483da38174148ad305c9f Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Wed, 15 Mar 2023 19:20:51 +0000 Subject: [PATCH 004/144] fix missing entropy regularization term --- src/estimators/advi.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/estimators/advi.jl b/src/estimators/advi.jl index 44f65909..ad45efbb 100644 --- a/src/estimators/advi.jl +++ b/src/estimators/advi.jl @@ -24,7 +24,7 @@ function estimate_gradient!( end 𝔼logdetjac = ∑logdetjac/n_samples - elbo = 𝔼logπ + 𝔼logdetjac + elbo = 𝔼logπ + 𝔼logdetjac + entropy(q) -elbo end nelbo = DiffResults.value(out) From efa810687738f4d297ff8b25aaadf28e37ba2080 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sat, 18 Mar 2023 01:04:02 +0000 Subject: [PATCH 005/144] add LogDensityProblem interface --- Project.toml | 1 + src/AdvancedVI.jl | 5 +++-- src/estimators/advi.jl | 19 ++++++++++++++++--- 3 files changed, 20 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index 28adc66a..6ad4b689 100644 --- 
a/Project.toml +++ b/Project.toml @@ -9,6 +9,7 @@ DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index d42683d0..e1ac752f 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -7,6 +7,8 @@ using DocStringExtensions using ProgressMeter, LinearAlgebra +using LogDensityProblems + using ForwardDiff using Tracker @@ -163,7 +165,6 @@ the steps. function optimize!( grad_estimator::AbstractGradientEstimator, rebuild::Function, - ℓπ::Function, n_max_iter::Int, λ::AbstractVector{<:Real}; optimizer = TruncatedADAGrad(), @@ -187,7 +188,7 @@ function optimize!( # add criterion? A running mean maybe? time_elapsed = @elapsed begin for i = 1:n_max_iter - stats = estimate_gradient!(rng, grad_estimator, λ, rebuild, ℓπ, grad_buf) + stats = estimate_gradient!(rng, grad_estimator, λ, rebuild, grad_buf) # apply update rule Δλ = DiffResults.gradient(grad_buf) diff --git a/src/estimators/advi.jl b/src/estimators/advi.jl index ad45efbb..5a8652b6 100644 --- a/src/estimators/advi.jl +++ b/src/estimators/advi.jl @@ -1,8 +1,22 @@ -struct ADVI <: AbstractGradientEstimator +struct ADVI{Tlogπ} <: AbstractGradientEstimator + ℓπ::Tlogπ n_samples::Int end +function ADVI(ℓπ, n_samples; kwargs...) 
+ # ADVI requires gradients of log-likelihood + cap = LogDensityProblems.capabilities(ℓπ) + if cap === nothing + throw( + ArgumentError( + "The log density function does not support the LogDensityProblems.jl interface", + ), + ) + end + ADVI(Base.Fix1(LogDensityProblems.logdensity, ℓπ), n_samples) +end + objective(::ADVI) = "ELBO" function estimate_gradient!( @@ -10,7 +24,6 @@ function estimate_gradient!( estimator::ADVI, λ::Vector{<:Real}, rebuild::Function, - logπ::Function, out::DiffResults.MutableDiffResult) n_samples = estimator.n_samples @@ -20,7 +33,7 @@ function estimate_gradient!( zs, ∑logdetjac = rand_and_logjac(rng, q, estimator.n_samples) 𝔼logπ = mapreduce(+, eachcol(zs)) do zᵢ - logπ(zᵢ) / n_samples + estimator.ℓπ(zᵢ) / n_samples end 𝔼logdetjac = ∑logdetjac/n_samples From 4ae2fbfa832662b5adaa7e3d423cb312cb87b4c9 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sat, 18 Mar 2023 02:22:32 +0000 Subject: [PATCH 006/144] refactor use bijectors directly instead of transformed distributions This is to avoid having to reconstruct transformed distributions all the time. The direct use of bijectors also avoids going through lots of abstraction layers that could break. Instead, transformed distributions could be constructed only once when returing the VI result. --- src/estimators/advi.jl | 43 ++++++++++++++++++++++++++---------------- src/utils.jl | 30 ----------------------------- 2 files changed, 27 insertions(+), 46 deletions(-) diff --git a/src/estimators/advi.jl b/src/estimators/advi.jl index 5a8652b6..9784e924 100644 --- a/src/estimators/advi.jl +++ b/src/estimators/advi.jl @@ -1,22 +1,32 @@ -struct ADVI{Tlogπ} <: AbstractGradientEstimator +struct ADVI{Tlogπ, B <: Union{Function, Bijectors.Inverse{<:Bijectors.Bijector}}} <: AbstractGradientEstimator + # Automatic differentiation variational inference + # + # Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). + # Automatic differentiation variational inference. 
+ # Journal of machine learning research. + ℓπ::Tlogπ + b⁻¹::B n_samples::Int -end -function ADVI(ℓπ, n_samples; kwargs...) - # ADVI requires gradients of log-likelihood - cap = LogDensityProblems.capabilities(ℓπ) - if cap === nothing - throw( - ArgumentError( - "The log density function does not support the LogDensityProblems.jl interface", - ), - ) + function ADVI(prob, b⁻¹::B, n_samples; kwargs...) where {B <: Bijectors.Inverse{<:Bijectors.Bijector}} + # Could check whether the support of b⁻¹ and ℓπ match + cap = LogDensityProblems.capabilities(prob) + if cap === nothing + throw( + ArgumentError( + "The log density function does not support the LogDensityProblems.jl interface", + ), + ) + end + ℓπ = Base.Fix1(LogDensityProblems.logdensity, prob) + new{typeof(ℓπ), typeof(b⁻¹)}(ℓπ, b⁻¹, n_samples) end - ADVI(Base.Fix1(LogDensityProblems.logdensity, ℓπ), n_samples) end +ADVI(prob, n_samples; kwargs...) = ADVI(prob, identity, n_samples; kwargs...) + objective(::ADVI) = "ELBO" function estimate_gradient!( @@ -29,18 +39,19 @@ function estimate_gradient!( n_samples = estimator.n_samples grad!(ADBackend(), λ, out) do λ′ - q = rebuild(λ′) - zs, ∑logdetjac = rand_and_logjac(rng, q, estimator.n_samples) + q_η = rebuild(λ′) + ηs = rand(rng, q_η, estimator.n_samples) + + zs, ∑logdetjac = Bijectors.with_logabsdet_jacobian(estimator.b⁻¹, ηs) 𝔼logπ = mapreduce(+, eachcol(zs)) do zᵢ estimator.ℓπ(zᵢ) / n_samples end 𝔼logdetjac = ∑logdetjac/n_samples - elbo = 𝔼logπ + 𝔼logdetjac + entropy(q) + elbo = 𝔼logπ + 𝔼logdetjac + entropy(q_η) -elbo end nelbo = DiffResults.value(out) (elbo=-nelbo,) end - diff --git a/src/utils.jl b/src/utils.jl index 87cc0856..e69de29b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,30 +0,0 @@ -using Distributions - -using Bijectors: Bijectors - - -function rand_and_logjac(rng::Random.AbstractRNG, dist::Distribution) - x = rand(rng, dist) - return x, zero(eltype(x)) -end - -function rand_and_logjac(rng::Random.AbstractRNG, 
dist::Bijectors.TransformedDistribution) - x = rand(rng, dist.dist) - y, logjac = Bijectors.with_logabsdet_jacobian(dist.transform, x) - return y, logjac -end - -function rand_and_logjac(rng::Random.AbstractRNG, dist::Distribution, n_samples::Int) - x = rand(rng, dist, n_samples) - return x, zero(eltype(x)) -end - -function rand_and_logjac(rng::Random.AbstractRNG, dist::Bijectors.TransformedDistribution, n_samples::Int) - x = rand(rng, dist.dist, n_samples) - y, logjac = Bijectors.with_logabsdet_jacobian(dist.transform, x) - return y, logjac -end - -function pm_next!(pm, stats::NamedTuple) - ProgressMeter.next!(pm; showvalues=[tuple(s...) for s in pairs(stats)]) -end From 1cadb51a011eeaf0b7d3e05aee7e45494bc2439a Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 8 Jun 2023 00:54:02 +0100 Subject: [PATCH 007/144] fix type restrictions --- src/estimators/advi.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/estimators/advi.jl b/src/estimators/advi.jl index 9784e924..b4b3a9d0 100644 --- a/src/estimators/advi.jl +++ b/src/estimators/advi.jl @@ -1,5 +1,5 @@ -struct ADVI{Tlogπ, B <: Union{Function, Bijectors.Inverse{<:Bijectors.Bijector}}} <: AbstractGradientEstimator +struct ADVI{Tlogπ, B} <: AbstractGradientEstimator # Automatic differentiation variational inference # # Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). 
@@ -33,7 +33,7 @@ function estimate_gradient!( rng::Random.AbstractRNG, estimator::ADVI, λ::Vector{<:Real}, - rebuild::Function, + rebuild, out::DiffResults.MutableDiffResult) n_samples = estimator.n_samples From 3474e8d2c97032f7a384d3b88cb7cc47bdae12f3 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 8 Jun 2023 00:54:23 +0100 Subject: [PATCH 008/144] remove unused file --- src/advi.jl | 52 ---------------------------------------------------- 1 file changed, 52 deletions(-) delete mode 100644 src/advi.jl diff --git a/src/advi.jl b/src/advi.jl deleted file mode 100644 index be9823db..00000000 --- a/src/advi.jl +++ /dev/null @@ -1,52 +0,0 @@ -using StatsFuns -using DistributionsAD -using Bijectors -using Bijectors: TransformedDistribution - - -""" -$(TYPEDEF) - -Automatic Differentiation Variational Inference (ADVI) with automatic differentiation -backend `AD`. - -# Fields - -$(TYPEDFIELDS) -""" -struct ADVI{AD} <: VariationalInference{AD} - "Number of samples used to estimate the ELBO in each optimization step." - samples_per_step::Int - "Maximum number of gradient steps." 
- max_iters::Int -end - -function ADVI(samples_per_step::Int=1, max_iters::Int=1000) - return ADVI{ADBackend()}(samples_per_step, max_iters) -end - -alg_str(::ADVI) = "ADVI" - -function vi(model, alg::ADVI, q, θ_init; optimizer = TruncatedADAGrad()) - θ = copy(θ_init) - optimize!(elbo, alg, q, model, θ; optimizer = optimizer) - - # If `q` is a mean-field approx we use the specialized `update` function - if q isa Distribution - return update(q, θ) - else - # Otherwise we assume it's a mapping θ → q - return q(θ) - end -end - - -function optimize(elbo::ELBO, alg::ADVI, q, model, θ_init; optimizer = TruncatedADAGrad()) - θ = copy(θ_init) - - # `model` assumed to be callable z ↦ p(x, z) - optimize!(elbo, alg, q, model, θ; optimizer = optimizer) - - return θ -end - From 03a27679f98790f943b784d0f6282035ecdc8abe Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 8 Jun 2023 03:19:03 +0100 Subject: [PATCH 009/144] fix use of with_logabsdet_jacobian --- src/estimators/advi.jl | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/estimators/advi.jl b/src/estimators/advi.jl index b4b3a9d0..701ec1ef 100644 --- a/src/estimators/advi.jl +++ b/src/estimators/advi.jl @@ -10,7 +10,7 @@ struct ADVI{Tlogπ, B} <: AbstractGradientEstimator b⁻¹::B n_samples::Int - function ADVI(prob, b⁻¹::B, n_samples; kwargs...) where {B <: Bijectors.Inverse{<:Bijectors.Bijector}} + function ADVI(prob, b⁻¹, n_samples; kwargs...) 
# Could check whether the support of b⁻¹ and ℓπ match cap = LogDensityProblems.capabilities(prob) if cap === nothing @@ -42,14 +42,12 @@ function estimate_gradient!( q_η = rebuild(λ′) ηs = rand(rng, q_η, estimator.n_samples) - zs, ∑logdetjac = Bijectors.with_logabsdet_jacobian(estimator.b⁻¹, ηs) - - 𝔼logπ = mapreduce(+, eachcol(zs)) do zᵢ - estimator.ℓπ(zᵢ) / n_samples + 𝔼ℓ = mapreduce(+, eachcol(ηs)) do ηᵢ + zᵢ, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(estimator.b⁻¹, ηᵢ) + (estimator.ℓπ(zᵢ) + logdetjacᵢ) / n_samples end - 𝔼logdetjac = ∑logdetjac/n_samples - elbo = 𝔼logπ + 𝔼logdetjac + entropy(q_η) + elbo = 𝔼ℓ + entropy(q_η) -elbo end nelbo = DiffResults.value(out) From 09c44fb639864167e6548db89b7ad0196d04ddfc Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 8 Jun 2023 03:29:42 +0100 Subject: [PATCH 010/144] restructure project; move the main VI routine to its own file --- src/AdvancedVI.jl | 60 +++++++----------------------------------- src/vi.jl | 66 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+), 50 deletions(-) create mode 100644 src/vi.jl diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index e1ac752f..d3612cb1 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -12,6 +12,13 @@ using LogDensityProblems using ForwardDiff using Tracker +using Bijectors: Bijectors + +using Distributions +using DistributionsAD + +using StatsFuns + const PROGRESS = Ref(true) function turnprogress(switch::Bool) @info("[AdvancedVI]: global PROGRESS is set as $switch") @@ -154,61 +161,14 @@ function grad!( DiffResults.gradient!(out, Tracker.grad(λ_tracked)) end +# estimators abstract type AbstractGradientEstimator end -""" - optimize!(vo, alg::VariationalInference{AD}, q::VariationalPosterior, model::Model, θ; optimizer = TruncatedADAGrad()) - -Iteratively updates parameters by calling `grad!` and using the given `optimizer` to compute -the steps. 
-""" -function optimize!( - grad_estimator::AbstractGradientEstimator, - rebuild::Function, - n_max_iter::Int, - λ::AbstractVector{<:Real}; - optimizer = TruncatedADAGrad(), - rng = Random.GLOBAL_RNG -) - obj_name = objective(grad_estimator) - - # TODO: really need a better way to warn the user about potentially - # not using the correct accumulator - if (optimizer isa TruncatedADAGrad) && (λ ∉ keys(optimizer.acc)) - # this message should only occurr once in the optimization process - @info "[$obj_name] Should only be seen once: optimizer created for θ" objectid(λ) - end - - grad_buf = DiffResults.GradientResult(λ) - - i = 0 - prog = ProgressMeter.Progress( - n_max_iter; desc="[$obj_name] Optimizing...", barlen=0, enabled=PROGRESS[]) - - # add criterion? A running mean maybe? - time_elapsed = @elapsed begin - for i = 1:n_max_iter - stats = estimate_gradient!(rng, grad_estimator, λ, rebuild, grad_buf) - - # apply update rule - Δλ = DiffResults.gradient(grad_buf) - Δλ = apply!(optimizer, λ, Δλ) - @. λ = λ - Δλ - - stat′ = (Δλ=norm(Δλ),) - stats = merge(stats, stat′) - - AdvancedVI.DEBUG && @debug "Step $i" stats... - pm_next!(prog, stats) - end - end - return λ -end - -# objectives include("estimators/advi.jl") # optimisers include("optimisers.jl") +include("vi.jl") + end # module diff --git a/src/vi.jl b/src/vi.jl new file mode 100644 index 00000000..aceb3f2d --- /dev/null +++ b/src/vi.jl @@ -0,0 +1,66 @@ + +function pm_next!(pm, stats::NamedTuple) + ProgressMeter.next!(pm; showvalues=[tuple(s...) for s in pairs(stats)]) +end + +""" + optimize!(vo, alg::VariationalInference{AD}, q::VariationalPosterior, model::Model, θ; optimizer = TruncatedADAGrad()) + +Iteratively updates parameters by calling `grad!` and using the given `optimizer` to compute +the steps. 
+""" +function optimize( + grad_estimator::AbstractGradientEstimator, + rebuild::Function, + n_max_iter::Int, + λ::AbstractVector{<:Real}; + optimizer = TruncatedADAGrad(), + rng = Random.GLOBAL_RNG +) + obj_name = objective(grad_estimator) + + # TODO: really need a better way to warn the user about potentially + # not using the correct accumulator + if (optimizer isa TruncatedADAGrad) && (λ ∉ keys(optimizer.acc)) + # this message should only occurr once in the optimization process + @info "[$obj_name] Should only be seen once: optimizer created for θ" objectid(λ) + end + + grad_buf = DiffResults.GradientResult(λ) + + i = 0 + prog = ProgressMeter.Progress( + n_max_iter; desc="[$obj_name] Optimizing...", barlen=0, enabled=PROGRESS[]) + + # add criterion? A running mean maybe? + time_elapsed = @elapsed begin + for i = 1:n_max_iter + stats = estimate_gradient!(rng, grad_estimator, λ, rebuild, grad_buf) + + # apply update rule + Δλ = DiffResults.gradient(grad_buf) + Δλ = apply!(optimizer, λ, Δλ) + @. λ = λ - Δλ + + stat′ = (Δλ=norm(Δλ),) + stats = merge(stats, stat′) + + AdvancedVI.DEBUG && @debug "Step $i" stats... 
+ pm_next!(prog, stats) + end + end + return λ +end + +# function vi(grad_estimator, q, θ_init; optimizer = TruncatedADAGrad(), rng = Random.GLOBAL_RNG) +# θ = copy(θ_init) +# optimize!(grad_estimator, rebuild, n_max_iter, λ, optimizer = optimizer, rng = rng) + +# # If `q` is a mean-field approx we use the specialized `update` function +# if q isa Distribution +# return update(q, θ) +# else +# # Otherwise we assume it's a mapping θ → q +# return q(θ) +# end +# end From b7407ceecd7f6c8e3fc7a4c443995347fd4659f5 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 8 Jun 2023 03:31:35 +0100 Subject: [PATCH 011/144] remove redundant import --- src/AdvancedVI.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index d3612cb1..32b114ba 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -12,8 +12,6 @@ using LogDensityProblems using ForwardDiff using Tracker -using Bijectors: Bijectors - using Distributions using DistributionsAD From 40401494ef032b1c9623856ed668373b251aaccb Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Fri, 9 Jun 2023 00:51:56 +0100 Subject: [PATCH 012/144] restructure project into more modular objective estimators --- src/AdvancedVI.jl | 8 ++--- src/estimators/advi.jl | 55 ------------------------------ src/objectives/elbo/advi_energy.jl | 35 +++++++++++++++++++ src/objectives/elbo/elbo.jl | 44 ++++++++++++++++++++++++ src/objectives/elbo/entropy.jl | 18 ++++++++++ src/vi.jl | 10 +++--- 6 files changed, 105 insertions(+), 65 deletions(-) delete mode 100644 src/estimators/advi.jl create mode 100644 src/objectives/elbo/advi_energy.jl create mode 100644 src/objectives/elbo/elbo.jl create mode 100644 src/objectives/elbo/entropy.jl diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 32b114ba..dfb22930 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -95,8 +95,6 @@ export TruncatedADAGrad, DecayedADAGrad -abstract type VariationalObjective end - const VariationalPosterior = Distribution{Multivariate, 
Continuous} @@ -160,9 +158,11 @@ function grad!( end # estimators -abstract type AbstractGradientEstimator end +abstract type AbstractVariationalObjective end -include("estimators/advi.jl") +include("objectives/elbo/elbo.jl") +include("objectives/elbo/advi_energy.jl") +include("objectives/elbo/entropy.jl") # optimisers include("optimisers.jl") diff --git a/src/estimators/advi.jl b/src/estimators/advi.jl deleted file mode 100644 index 701ec1ef..00000000 --- a/src/estimators/advi.jl +++ /dev/null @@ -1,55 +0,0 @@ - -struct ADVI{Tlogπ, B} <: AbstractGradientEstimator - # Automatic differentiation variational inference - # - # Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). - # Automatic differentiation variational inference. - # Journal of machine learning research. - - ℓπ::Tlogπ - b⁻¹::B - n_samples::Int - - function ADVI(prob, b⁻¹, n_samples; kwargs...) - # Could check whether the support of b⁻¹ and ℓπ match - cap = LogDensityProblems.capabilities(prob) - if cap === nothing - throw( - ArgumentError( - "The log density function does not support the LogDensityProblems.jl interface", - ), - ) - end - ℓπ = Base.Fix1(LogDensityProblems.logdensity, prob) - new{typeof(ℓπ), typeof(b⁻¹)}(ℓπ, b⁻¹, n_samples) - end -end - -ADVI(prob, n_samples; kwargs...) = ADVI(prob, identity, n_samples; kwargs...) 
- -objective(::ADVI) = "ELBO" - -function estimate_gradient!( - rng::Random.AbstractRNG, - estimator::ADVI, - λ::Vector{<:Real}, - rebuild, - out::DiffResults.MutableDiffResult) - - n_samples = estimator.n_samples - - grad!(ADBackend(), λ, out) do λ′ - q_η = rebuild(λ′) - ηs = rand(rng, q_η, estimator.n_samples) - - 𝔼ℓ = mapreduce(+, eachcol(ηs)) do ηᵢ - zᵢ, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(estimator.b⁻¹, ηᵢ) - (estimator.ℓπ(zᵢ) + logdetjacᵢ) / n_samples - end - - elbo = 𝔼ℓ + entropy(q_η) - -elbo - end - nelbo = DiffResults.value(out) - (elbo=-nelbo,) -end diff --git a/src/objectives/elbo/advi_energy.jl b/src/objectives/elbo/advi_energy.jl new file mode 100644 index 00000000..b27b752e --- /dev/null +++ b/src/objectives/elbo/advi_energy.jl @@ -0,0 +1,35 @@ + +struct ADVIEnergy{Tlogπ, B} <: AbstractEnergyEstimator + # Automatic differentiation variational inference + # + # Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). + # Automatic differentiation variational inference. + # Journal of machine learning research. 
+ + ℓπ::Tlogπ + b⁻¹::B + + function ADVIEnergy(prob, b⁻¹) + # Could check whether the support of b⁻¹ and ℓπ match + cap = LogDensityProblems.capabilities(prob) + if cap === nothing + throw( + ArgumentError( + "The log density function does not support the LogDensityProblems.jl interface", + ), + ) + end + ℓπ = Base.Fix1(LogDensityProblems.logdensity, prob) + new{typeof(ℓπ), typeof(b⁻¹)}(ℓπ, b⁻¹) + end +end + +ADVIEnergy(prob) = ADVIEnergy(prob, identity) + +function (energy::ADVIEnergy)(q, ηs::AbstractMatrix) + n_samples = size(ηs, 2) + mapreduce(+, eachcol(ηs)) do ηᵢ + zᵢ, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(energy.b⁻¹, ηᵢ) + (energy.ℓπ(zᵢ) + logdetjacᵢ) / n_samples + end +end diff --git a/src/objectives/elbo/elbo.jl b/src/objectives/elbo/elbo.jl new file mode 100644 index 00000000..2954ae8e --- /dev/null +++ b/src/objectives/elbo/elbo.jl @@ -0,0 +1,44 @@ + +abstract type AbstractEnergyEstimator end +abstract type AbstractEntropyEstimator end + +struct ELBO{EnergyEst <: AbstractEnergyEstimator, + EntropyEst <: AbstractEntropyEstimator} <: AbstractVariationalObjective + # Evidence Lower Bound + # + # Jordan, Michael I., et al. + # "An introduction to variational methods for graphical models." + # Machine learning 37 (1999): 183-233. 
+ + energy_estimator::EnergyEst + entropy_estimator::EntropyEst + n_samples::Int +end + +Base.string(::ELBO) = "ELBO" + +function ADVI(ℓπ, b⁻¹, n_samples::Int) + ELBO(ADVIEnergy(ℓπ, b⁻¹), ClosedFormEntropy(), n_samples) +end + +function estimate_gradient!( + rng::Random.AbstractRNG, + objective::ELBO, + λ::Vector{<:Real}, + rebuild, + out::DiffResults.MutableDiffResult) + + n_samples = objective.n_samples + + grad!(ADBackend(), λ, out) do λ′ + q_η = rebuild(λ′) + ηs = rand(rng, q_η, n_samples) + + 𝔼ℓ = objective.energy_estimator(q_η, ηs) + ℍ = objective.entropy_estimator(q_η, ηs) + elbo = 𝔼ℓ + ℍ + -elbo + end + nelbo = DiffResults.value(out) + (elbo=-nelbo,) +end diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl new file mode 100644 index 00000000..d7fb7054 --- /dev/null +++ b/src/objectives/elbo/entropy.jl @@ -0,0 +1,18 @@ + +struct ClosedFormEntropy <: AbstractEntropyEstimator +end + +function (::ClosedFormEntropy)(q, ηs::AbstractMatrix) + entropy(q) +end + +struct MonteCarloEntropy <: AbstractEntropyEstimator +end + +function (::MonteCarloEntropy)(q, ηs::AbstractMatrix) + n_samples = size(ηs, 2) + mapreduce(+, eachcol(ηs)) do ηᵢ + -logpdf(q, ηᵢ) / n_samples + end +end + diff --git a/src/vi.jl b/src/vi.jl index aceb3f2d..4bf4595f 100644 --- a/src/vi.jl +++ b/src/vi.jl @@ -10,32 +10,30 @@ Iteratively updates parameters by calling `grad!` and using the given `optimizer the steps. 
""" function optimize( - grad_estimator::AbstractGradientEstimator, + objective::AbstractVariationalObjective, rebuild::Function, n_max_iter::Int, λ::AbstractVector{<:Real}; optimizer = TruncatedADAGrad(), rng = Random.GLOBAL_RNG ) - obj_name = objective(grad_estimator) - # TODO: really need a better way to warn the user about potentially # not using the correct accumulator if (optimizer isa TruncatedADAGrad) && (λ ∉ keys(optimizer.acc)) # this message should only occurr once in the optimization process - @info "[$obj_name] Should only be seen once: optimizer created for θ" objectid(λ) + @info "[$(string(objective))] Should only be seen once: optimizer created for θ" objectid(λ) end grad_buf = DiffResults.GradientResult(λ) i = 0 prog = ProgressMeter.Progress( - n_max_iter; desc="[$obj_name] Optimizing...", barlen=0, enabled=PROGRESS[]) + n_max_iter; desc="[$(string(objective))] Optimizing...", barlen=0, enabled=PROGRESS[]) # add criterion? A running mean maybe? time_elapsed = @elapsed begin for i = 1:n_max_iter - stats = estimate_gradient!(rng, grad_estimator, λ, rebuild, grad_buf) + stats = estimate_gradient!(rng, objective, λ, rebuild, grad_buf) # apply update rule Δλ = DiffResults.gradient(grad_buf) From 2a4514e4ff0ab0459b7ed78dcdee2f61be61c691 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Fri, 9 Jun 2023 01:18:02 +0100 Subject: [PATCH 013/144] migrate to AbstractDifferentiation --- Project.toml | 3 +- src/AdvancedVI.jl | 101 ++---------------------------------- src/objectives/elbo/elbo.jl | 10 ++-- src/vi.jl | 8 ++- 4 files changed, 13 insertions(+), 109 deletions(-) diff --git a/Project.toml b/Project.toml index e73037ec..6964c135 100644 --- a/Project.toml +++ b/Project.toml @@ -3,6 +3,7 @@ uuid = "b5ca4192-6429-45e5-a2d9-87aec30a685c" version = "0.2.3" [deps] +AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = 
"ced4e74d-a319-5a8a-b0ac-84af2272839c" @@ -15,7 +16,6 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" -Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" [compat] Bijectors = "0.11, 0.12" @@ -27,7 +27,6 @@ ProgressMeter = "1.0.0" Requires = "0.5, 1.0" StatsBase = "0.32, 0.33, 0.34" StatsFuns = "0.8, 0.9, 1" -Tracker = "0.2.3" julia = "1.6" [extras] diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index dfb22930..809d86c6 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -9,14 +9,16 @@ using ProgressMeter, LinearAlgebra using LogDensityProblems -using ForwardDiff -using Tracker - using Distributions using DistributionsAD using StatsFuns +using ForwardDiff +import AbstractDifferentiation as AD + +value_and_gradient(f, xs...; adbackend) = AD.value_and_gradient(adbackend, f, xs...) + const PROGRESS = Ref(true) function turnprogress(switch::Bool) @info("[AdvancedVI]: global PROGRESS is set as $switch") @@ -35,58 +37,6 @@ function __init__() Flux.Optimise.apply!(o::TruncatedADAGrad, x, Δ) = apply!(o, x, Δ) Flux.Optimise.apply!(o::DecayedADAGrad, x, Δ) = apply!(o, x, Δ) end - @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin - include("compat/zygote.jl") - export ZygoteAD - - function AdvancedVI.grad!( - f::Function, - ::Type{<:ZygoteAD}, - λ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult, - ) - y, back = Zygote.pullback(f, λ) - dy = first(back(1.0)) - DiffResults.value!(out, y) - DiffResults.gradient!(out, dy) - return out - end - end - @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin - include("compat/reversediff.jl") - export ReverseDiffAD - - function AdvancedVI.grad!( - f::Function, - ::Type{<:ReverseDiffAD}, - λ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult, - ) - tp = AdvancedVI.tape(f, λ) - ReverseDiff.gradient!(out, tp, λ) - return out - end 
- end - @require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" begin - include("compat/enzyme.jl") - export EnzymeAD - - function AdvancedVI.grad!( - f::Function, - ::Type{<:EnzymeAD}, - λ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult, - ) - # Use `Enzyme.ReverseWithPrimal` once it is released: - # https://github.com/EnzymeAD/Enzyme.jl/pull/598 - y = f(λ) - DiffResults.value!(out, y) - dy = DiffResults.gradient(out) - fill!(dy, 0) - Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(λ, dy)) - return out - end - end end export @@ -97,16 +47,6 @@ export const VariationalPosterior = Distribution{Multivariate, Continuous} - -""" - grad!(f, λ, out) - -Computes the gradients of the objective f. Default implementation is provided for -`VariationalInference{AD}` where `AD` is either `ForwardDiffAD` or `TrackerAD`. -This implicitly also gives a default implementation of `optimize!`. -""" -function grad! end - """ vi(model, alg::VariationalInference) vi(model, alg::VariationalInference, q::VariationalPosterior) @@ -126,37 +66,6 @@ function vi end function update end -# default implementations -function grad!( - f::Function, - adtype::Type{<:ForwardDiffAD}, - λ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult -) - # Set chunk size and do ForwardMode. 
- chunk_size = getchunksize(adtype) - config = if chunk_size == 0 - ForwardDiff.GradientConfig(f, λ) - else - ForwardDiff.GradientConfig(f, λ, ForwardDiff.Chunk(length(λ), chunk_size)) - end - ForwardDiff.gradient!(out, f, λ, config) -end - -function grad!( - f::Function, - ::Type{<:TrackerAD}, - λ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult -) - λ_tracked = Tracker.param(λ) - y = f(λ_tracked) - Tracker.back!(y, 1.0) - - DiffResults.value!(out, Tracker.data(y)) - DiffResults.gradient!(out, Tracker.grad(λ_tracked)) -end - # estimators abstract type AbstractVariationalObjective end diff --git a/src/objectives/elbo/elbo.jl b/src/objectives/elbo/elbo.jl index 2954ae8e..213cc725 100644 --- a/src/objectives/elbo/elbo.jl +++ b/src/objectives/elbo/elbo.jl @@ -22,15 +22,14 @@ function ADVI(ℓπ, b⁻¹, n_samples::Int) end function estimate_gradient!( + adbackend::AD.AbstractBackend, rng::Random.AbstractRNG, objective::ELBO, λ::Vector{<:Real}, - rebuild, - out::DiffResults.MutableDiffResult) + rebuild) n_samples = objective.n_samples - - grad!(ADBackend(), λ, out) do λ′ + nelbo, grad = value_and_gradient(λ; adbackend) do λ′ q_η = rebuild(λ′) ηs = rand(rng, q_η, n_samples) @@ -39,6 +38,5 @@ function estimate_gradient!( elbo = 𝔼ℓ + ℍ -elbo end - nelbo = DiffResults.value(out) - (elbo=-nelbo,) + first(grad), (elbo=-nelbo,) end diff --git a/src/vi.jl b/src/vi.jl index 4bf4595f..7b7858b8 100644 --- a/src/vi.jl +++ b/src/vi.jl @@ -15,7 +15,8 @@ function optimize( n_max_iter::Int, λ::AbstractVector{<:Real}; optimizer = TruncatedADAGrad(), - rng = Random.GLOBAL_RNG + rng = Random.default_rng(), + adbackend = AD.ForwardDiffBackend() ) # TODO: really need a better way to warn the user about potentially # not using the correct accumulator @@ -24,8 +25,6 @@ function optimize( @info "[$(string(objective))] Should only be seen once: optimizer created for θ" objectid(λ) end - grad_buf = DiffResults.GradientResult(λ) - i = 0 prog = ProgressMeter.Progress( n_max_iter; 
desc="[$(string(objective))] Optimizing...", barlen=0, enabled=PROGRESS[]) @@ -33,10 +32,9 @@ function optimize( # add criterion? A running mean maybe? time_elapsed = @elapsed begin for i = 1:n_max_iter - stats = estimate_gradient!(rng, objective, λ, rebuild, grad_buf) + Δλ, stats = estimate_gradient!(adbackend, rng, objective, λ, rebuild) # apply update rule - Δλ = DiffResults.gradient(grad_buf) Δλ = apply!(optimizer, λ, Δλ) @. λ = λ - Δλ From 93a16d8bc6aac9725081ea4c414ffd9343e6e79e Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sat, 10 Jun 2023 00:42:36 +0100 Subject: [PATCH 014/144] add location scale pre-packaged variational family, add functors --- Project.toml | 2 ++ src/AdvancedVI.jl | 19 +++++++++++++---- src/distributions/location_scale.jl | 33 +++++++++++++++++++++++++++++ 3 files changed, 50 insertions(+), 4 deletions(-) create mode 100644 src/distributions/location_scale.jl diff --git a/Project.toml b/Project.toml index 6964c135..88342f19 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" @@ -16,6 +17,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" [compat] Bijectors = "0.11, 0.12" diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 809d86c6..8c33f74a 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -2,6 +2,8 @@ module AdvancedVI using Random: Random +using Functors + using Distributions, 
DistributionsAD, Bijectors using DocStringExtensions @@ -13,8 +15,9 @@ using Distributions using DistributionsAD using StatsFuns +import StatsBase: entropy -using ForwardDiff +using ForwardDiff, Tracker import AbstractDifferentiation as AD value_and_gradient(f, xs...; adbackend) = AD.value_and_gradient(adbackend, f, xs...) @@ -40,13 +43,18 @@ function __init__() end export - vi, + optimize, + ELBO, ADVI, + ADVIEnergy, + ClosedFormEntropy, + MonteCarloEntropy, + LocationScale, + FullRankGaussian, + MeanFieldGaussian, TruncatedADAGrad, DecayedADAGrad -const VariationalPosterior = Distribution{Multivariate, Continuous} - """ vi(model, alg::VariationalInference) vi(model, alg::VariationalInference, q::VariationalPosterior) @@ -73,6 +81,9 @@ include("objectives/elbo/elbo.jl") include("objectives/elbo/advi_energy.jl") include("objectives/elbo/entropy.jl") +# Variational Families +include("distributions/location_scale.jl") + # optimisers include("optimisers.jl") diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl new file mode 100644 index 00000000..3aba53c5 --- /dev/null +++ b/src/distributions/location_scale.jl @@ -0,0 +1,33 @@ + +LocationScale(μ::LinearAlgebra.AbstractVector, + L::Union{<: LinearAlgebra.AbstractTriangular, + <: LinearAlgebra.Diagonal}, + q₀::Distributions.ContinuousMultivariateDistribution) = + transformed(q₀, Bijectors.Shift(μ) ∘ Bijectors.Scale(L)) + +function location_scale_entropy( + q₀::Distributions.ContinuousMultivariateDistribution, + locscale_bijector::Bijectors.ComposedFunction) +end + +function entropy(q_trans::MultivariateTransformed{ + <: Distributions.ContinuousMultivariateDistribution, + <: Bijectors.ComposedFunction{ + <: Bijectors.Shift, + <: Bijectors.Scale}}) + q_base = q_trans.dist + scale = q_trans.transform.inner.a + entropy(q_base) + first(logabsdet(scale)) +end + +function FullRankGaussian(μ::AbstractVector, + L::LinearAlgebra.AbstractTriangular) + q₀ = MvNormal(zeros(eltype(μ), length(μ)), 
one(eltype(μ))) + LocationScale(μ, L, q₀) +end + +function MeanFieldGaussian(μ::AbstractVector, + L::LinearAlgebra.Diagonal) + q₀ = MvNormal(zeros(eltype(μ), length(μ)), one(eltype(μ))) + LocationScale(μ, L, q₀) +end From 2b6e9ebed556dd67bb9325a5b04228637e1e03df Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sat, 10 Jun 2023 21:04:19 +0100 Subject: [PATCH 015/144] Revert "migrate to AbstractDifferentiation" This reverts commit 2a4514e4ff0ab0459b7ed78dcdee2f61be61c691. --- Project.toml | 2 +- src/AdvancedVI.jl | 101 ++++++++++++++++++++++++++++++++++-- src/objectives/elbo/elbo.jl | 10 ++-- src/vi.jl | 8 +-- 4 files changed, 108 insertions(+), 13 deletions(-) diff --git a/Project.toml b/Project.toml index 88342f19..9a3303f5 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,6 @@ uuid = "b5ca4192-6429-45e5-a2d9-87aec30a685c" version = "0.2.3" [deps] -AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" @@ -29,6 +28,7 @@ ProgressMeter = "1.0.0" Requires = "0.5, 1.0" StatsBase = "0.32, 0.33, 0.34" StatsFuns = "0.8, 0.9, 1" +Tracker = "0.2.3" julia = "1.6" [extras] diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 8c33f74a..116bb63c 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -11,17 +11,15 @@ using ProgressMeter, LinearAlgebra using LogDensityProblems +using ForwardDiff +using Tracker + using Distributions using DistributionsAD using StatsFuns import StatsBase: entropy -using ForwardDiff, Tracker -import AbstractDifferentiation as AD - -value_and_gradient(f, xs...; adbackend) = AD.value_and_gradient(adbackend, f, xs...) 
- const PROGRESS = Ref(true) function turnprogress(switch::Bool) @info("[AdvancedVI]: global PROGRESS is set as $switch") @@ -40,6 +38,58 @@ function __init__() Flux.Optimise.apply!(o::TruncatedADAGrad, x, Δ) = apply!(o, x, Δ) Flux.Optimise.apply!(o::DecayedADAGrad, x, Δ) = apply!(o, x, Δ) end + @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin + include("compat/zygote.jl") + export ZygoteAD + + function AdvancedVI.grad!( + f::Function, + ::Type{<:ZygoteAD}, + λ::AbstractVector{<:Real}, + out::DiffResults.MutableDiffResult, + ) + y, back = Zygote.pullback(f, λ) + dy = first(back(1.0)) + DiffResults.value!(out, y) + DiffResults.gradient!(out, dy) + return out + end + end + @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin + include("compat/reversediff.jl") + export ReverseDiffAD + + function AdvancedVI.grad!( + f::Function, + ::Type{<:ReverseDiffAD}, + λ::AbstractVector{<:Real}, + out::DiffResults.MutableDiffResult, + ) + tp = AdvancedVI.tape(f, λ) + ReverseDiff.gradient!(out, tp, λ) + return out + end + end + @require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" begin + include("compat/enzyme.jl") + export EnzymeAD + + function AdvancedVI.grad!( + f::Function, + ::Type{<:EnzymeAD}, + λ::AbstractVector{<:Real}, + out::DiffResults.MutableDiffResult, + ) + # Use `Enzyme.ReverseWithPrimal` once it is released: + # https://github.com/EnzymeAD/Enzyme.jl/pull/598 + y = f(λ) + DiffResults.value!(out, y) + dy = DiffResults.gradient(out) + fill!(dy, 0) + Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(λ, dy)) + return out + end + end end export @@ -55,6 +105,16 @@ export TruncatedADAGrad, DecayedADAGrad + +""" + grad!(f, λ, out) + +Computes the gradients of the objective f. Default implementation is provided for +`VariationalInference{AD}` where `AD` is either `ForwardDiffAD` or `TrackerAD`. +This implicitly also gives a default implementation of `optimize!`. +""" +function grad! 
end + """ vi(model, alg::VariationalInference) vi(model, alg::VariationalInference, q::VariationalPosterior) @@ -74,6 +134,37 @@ function vi end function update end +# default implementations +function grad!( + f::Function, + adtype::Type{<:ForwardDiffAD}, + λ::AbstractVector{<:Real}, + out::DiffResults.MutableDiffResult +) + # Set chunk size and do ForwardMode. + chunk_size = getchunksize(adtype) + config = if chunk_size == 0 + ForwardDiff.GradientConfig(f, λ) + else + ForwardDiff.GradientConfig(f, λ, ForwardDiff.Chunk(length(λ), chunk_size)) + end + ForwardDiff.gradient!(out, f, λ, config) +end + +function grad!( + f::Function, + ::Type{<:TrackerAD}, + λ::AbstractVector{<:Real}, + out::DiffResults.MutableDiffResult +) + λ_tracked = Tracker.param(λ) + y = f(λ_tracked) + Tracker.back!(y, 1.0) + + DiffResults.value!(out, Tracker.data(y)) + DiffResults.gradient!(out, Tracker.grad(λ_tracked)) +end + # estimators abstract type AbstractVariationalObjective end diff --git a/src/objectives/elbo/elbo.jl b/src/objectives/elbo/elbo.jl index 213cc725..2954ae8e 100644 --- a/src/objectives/elbo/elbo.jl +++ b/src/objectives/elbo/elbo.jl @@ -22,14 +22,15 @@ function ADVI(ℓπ, b⁻¹, n_samples::Int) end function estimate_gradient!( - adbackend::AD.AbstractBackend, rng::Random.AbstractRNG, objective::ELBO, λ::Vector{<:Real}, - rebuild) + rebuild, + out::DiffResults.MutableDiffResult) n_samples = objective.n_samples - nelbo, grad = value_and_gradient(λ; adbackend) do λ′ + + grad!(ADBackend(), λ, out) do λ′ q_η = rebuild(λ′) ηs = rand(rng, q_η, n_samples) @@ -38,5 +39,6 @@ function estimate_gradient!( elbo = 𝔼ℓ + ℍ -elbo end - first(grad), (elbo=-nelbo,) + nelbo = DiffResults.value(out) + (elbo=-nelbo,) end diff --git a/src/vi.jl b/src/vi.jl index 7b7858b8..4bf4595f 100644 --- a/src/vi.jl +++ b/src/vi.jl @@ -15,8 +15,7 @@ function optimize( n_max_iter::Int, λ::AbstractVector{<:Real}; optimizer = TruncatedADAGrad(), - rng = Random.default_rng(), - adbackend = AD.ForwardDiffBackend() + 
rng = Random.GLOBAL_RNG ) # TODO: really need a better way to warn the user about potentially # not using the correct accumulator @@ -25,6 +24,8 @@ function optimize( @info "[$(string(objective))] Should only be seen once: optimizer created for θ" objectid(λ) end + grad_buf = DiffResults.GradientResult(λ) + i = 0 prog = ProgressMeter.Progress( n_max_iter; desc="[$(string(objective))] Optimizing...", barlen=0, enabled=PROGRESS[]) @@ -32,9 +33,10 @@ function optimize( # add criterion? A running mean maybe? time_elapsed = @elapsed begin for i = 1:n_max_iter - Δλ, stats = estimate_gradient!(adbackend, rng, objective, λ, rebuild) + stats = estimate_gradient!(rng, objective, λ, rebuild, grad_buf) # apply update rule + Δλ = DiffResults.gradient(grad_buf) Δλ = apply!(optimizer, λ, Δλ) @. λ = λ - Δλ From 1bfec36961c437cf000234bd29504fd49848d676 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sat, 10 Jun 2023 21:41:25 +0100 Subject: [PATCH 016/144] fix use optimized MvNormal specialization, add logpdf for Loc.Scale. 
--- Project.toml | 2 + src/AdvancedVI.jl | 23 +++++++----- src/distributions/location_scale.jl | 57 +++++++++++++++++++---------- 3 files changed, 53 insertions(+), 29 deletions(-) diff --git a/Project.toml b/Project.toml index 9a3303f5..38a5026a 100644 --- a/Project.toml +++ b/Project.toml @@ -7,10 +7,12 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" +PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 116bb63c..d5a06fce 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -4,20 +4,23 @@ using Random: Random using Functors -using Distributions, DistributionsAD, Bijectors using DocStringExtensions -using ProgressMeter, LinearAlgebra +using ProgressMeter +using LinearAlgebra +using LinearAlgebra: AbstractTriangular using LogDensityProblems using ForwardDiff using Tracker -using Distributions -using DistributionsAD +using FillArrays +using PDMats +using Distributions, DistributionsAD +using Distributions: ContinuousMultivariateDistribution +using Bijectors -using StatsFuns import StatsBase: entropy const PROGRESS = Ref(true) @@ -29,7 +32,6 @@ end const DEBUG = Bool(parse(Int, get(ENV, "DEBUG_ADVANCEDVI", "0"))) include("ad.jl") -include("utils.jl") using Requires function __init__() @@ -116,9 +118,9 @@ This implicitly also gives a default implementation of `optimize!`. function grad! 
end """ - vi(model, alg::VariationalInference) - vi(model, alg::VariationalInference, q::VariationalPosterior) - vi(model, alg::VariationalInference, getq::Function, θ::AbstractArray) + optimize(model, alg::VariationalInference) + optimize(model, alg::VariationalInference, q::VariationalPosterior) + optimize(model, alg::VariationalInference, getq::Function, θ::AbstractArray) Constructs the variational posterior from the `model` and performs the optimization following the configuration of the given `VariationalInference` instance. @@ -130,7 +132,7 @@ following the configuration of the given `VariationalInference` instance. - `getq`: function taking parameters `θ` as input and returns a `VariationalPosterior` - `θ`: only required if `getq` is used, in which case it is the initial parameters for the variational posterior """ -function vi end +function optimize end function update end @@ -178,6 +180,7 @@ include("distributions/location_scale.jl") # optimisers include("optimisers.jl") +include("utils.jl") include("vi.jl") end # module diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl index 3aba53c5..365ae15e 100644 --- a/src/distributions/location_scale.jl +++ b/src/distributions/location_scale.jl @@ -1,33 +1,52 @@ -LocationScale(μ::LinearAlgebra.AbstractVector, - L::Union{<: LinearAlgebra.AbstractTriangular, - <: LinearAlgebra.Diagonal}, - q₀::Distributions.ContinuousMultivariateDistribution) = - transformed(q₀, Bijectors.Shift(μ) ∘ Bijectors.Scale(L)) +function LocationScale(μ::AbstractVector, + L::Union{<: AbstractTriangular, + <: Diagonal}, + q₀::ContinuousMultivariateDistribution) + @assert (length(μ) == size(L,1)) && (length(μ) == size(L,2)) + transformed(q₀, Bijectors.Shift(μ) ∘ Bijectors.Scale(L)) +end function location_scale_entropy( - q₀::Distributions.ContinuousMultivariateDistribution, + q₀::ContinuousMultivariateDistribution, locscale_bijector::Bijectors.ComposedFunction) end -function 
entropy(q_trans::MultivariateTransformed{ - <: Distributions.ContinuousMultivariateDistribution, - <: Bijectors.ComposedFunction{ - <: Bijectors.Shift, - <: Bijectors.Scale}}) +function entropy(q_trans::MultivariateTransformed{<: ContinuousMultivariateDistribution, + <: Bijectors.ComposedFunction{ + <: Bijectors.Shift, + <: Bijectors.Scale}}) q_base = q_trans.dist scale = q_trans.transform.inner.a entropy(q_base) + first(logabsdet(scale)) end -function FullRankGaussian(μ::AbstractVector, - L::LinearAlgebra.AbstractTriangular) - q₀ = MvNormal(zeros(eltype(μ), length(μ)), one(eltype(μ))) - LocationScale(μ, L, q₀) +function logpdf(q_trans::MultivariateTransformed{<: ContinuousMultivariateDistribution, + <: Bijectors.ComposedFunction{ + <: Bijectors.Shift, + <: Bijectors.Scale}}, + z::AbstractVector) + q_base = q_trans.dist + reparam = q_trans.transform + scale = q_trans.transform.inner.a + η = inverse(reparam)(z) + logpdf(q_base, η) - first(logabsdet(scale)) +end + +function FullRankGaussian(μ::AbstractVector{T}, + L::AbstractTriangular{T,S}) where {T <: Real, S} + @assert (length(μ) == size(L,1)) && (length(μ) == size(L,2)) + n_dims = length(μ) + q_base = MvNormal(FillArrays.Zeros{T}(n_dims), + PDMats.ScalMat{T}(n_dims, one(T))) + LocationScale(μ, L, q_base) end -function MeanFieldGaussian(μ::AbstractVector, - L::LinearAlgebra.Diagonal) - q₀ = MvNormal(zeros(eltype(μ), length(μ)), one(eltype(μ))) - LocationScale(μ, L, q₀) +function MeanFieldGaussian(μ::AbstractVector{T}, + L::Diagonal{T,V}) where {T <: Real, V} + @assert (length(μ) == size(L,1)) + n_dims = length(μ) + q_base = MvNormal(FillArrays.Zeros{T}(n_dims), + PDMats.ScalMat{T}(n_dims, one(T))) + LocationScale(μ, L, q_base) end From 1003606283efd6b8cf340e74dced65d8ea72b296 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sat, 10 Jun 2023 21:52:53 +0100 Subject: [PATCH 017/144] remove dead code --- src/distributions/location_scale.jl | 5 ----- 1 file changed, 5 deletions(-) diff --git 
a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl index 365ae15e..1f7bad85 100644 --- a/src/distributions/location_scale.jl +++ b/src/distributions/location_scale.jl @@ -7,11 +7,6 @@ function LocationScale(μ::AbstractVector, transformed(q₀, Bijectors.Shift(μ) ∘ Bijectors.Scale(L)) end -function location_scale_entropy( - q₀::ContinuousMultivariateDistribution, - locscale_bijector::Bijectors.ComposedFunction) -end - function entropy(q_trans::MultivariateTransformed{<: ContinuousMultivariateDistribution, <: Bijectors.ComposedFunction{ <: Bijectors.Shift, From 60a9987ed259b906da9cdd6e38ed33102497f389 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sat, 10 Jun 2023 21:56:30 +0100 Subject: [PATCH 018/144] fix location-scale logpdf - Full Monte Carlo ELBO estimation now works. I checked. --- src/AdvancedVI.jl | 3 ++- src/distributions/location_scale.jl | 20 +++++++++++--------- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index d5a06fce..9b9d3ab2 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -21,7 +21,8 @@ using Distributions, DistributionsAD using Distributions: ContinuousMultivariateDistribution using Bijectors -import StatsBase: entropy +using StatsBase +using StatsBase: entropy const PROGRESS = Ref(true) function turnprogress(switch::Bool) diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl index 1f7bad85..dd9b5f2a 100644 --- a/src/distributions/location_scale.jl +++ b/src/distributions/location_scale.jl @@ -7,20 +7,22 @@ function LocationScale(μ::AbstractVector, transformed(q₀, Bijectors.Shift(μ) ∘ Bijectors.Scale(L)) end -function entropy(q_trans::MultivariateTransformed{<: ContinuousMultivariateDistribution, - <: Bijectors.ComposedFunction{ - <: Bijectors.Shift, - <: Bijectors.Scale}}) +function StatsBase.entropy( + q_trans::MultivariateTransformed{<: ContinuousMultivariateDistribution, + <: Bijectors.ComposedFunction{ + <: Bijectors.Shift, 
+ <: Bijectors.Scale}}) q_base = q_trans.dist scale = q_trans.transform.inner.a entropy(q_base) + first(logabsdet(scale)) end -function logpdf(q_trans::MultivariateTransformed{<: ContinuousMultivariateDistribution, - <: Bijectors.ComposedFunction{ - <: Bijectors.Shift, - <: Bijectors.Scale}}, - z::AbstractVector) +function Distributions.logpdf( + q_trans::MultivariateTransformed{<: ContinuousMultivariateDistribution, + <: Bijectors.ComposedFunction{ + <: Bijectors.Shift, + <: Bijectors.Scale}}, + z::AbstractVector) q_base = q_trans.dist reparam = q_trans.transform scale = q_trans.transform.inner.a From cd84f02898d7cf82c530f98a91579f0b01935f33 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sat, 10 Jun 2023 22:21:22 +0100 Subject: [PATCH 019/144] add sticking-the-landing (STL) estimator --- src/objectives/elbo/elbo.jl | 36 ++++++++++++++++++++++++---------- src/objectives/elbo/entropy.jl | 35 ++++++++++++++++++++++++++++----- 2 files changed, 56 insertions(+), 15 deletions(-) diff --git a/src/objectives/elbo/elbo.jl b/src/objectives/elbo/elbo.jl index 2954ae8e..343581d8 100644 --- a/src/objectives/elbo/elbo.jl +++ b/src/objectives/elbo/elbo.jl @@ -21,23 +21,39 @@ function ADVI(ℓπ, b⁻¹, n_samples::Int) ELBO(ADVIEnergy(ℓπ, b⁻¹), ClosedFormEntropy(), n_samples) end +function (elbo::ELBO)(q_η::ContinuousMultivariateDistribution; + rng = Random.default_rng(), + n_samples::Int = elbo.n_samples, + q_η_entropy::ContinuousMultivariateDistribution = q_η) + ηs = rand(rng, q_η, n_samples) + 𝔼ℓ = elbo.energy_estimator(q_η, ηs) + ℍ = elbo.entropy_estimator(q_η_entropy, ηs) + 𝔼ℓ + ℍ +end + function estimate_gradient!( rng::Random.AbstractRNG, - objective::ELBO, + elbo::ELBO{EnergyEst, EntropyEst}, λ::Vector{<:Real}, rebuild, - out::DiffResults.MutableDiffResult) - - n_samples = objective.n_samples + out::DiffResults.MutableDiffResult) where {EnergyEst <: AbstractEnergyEstimator, + EntropyEst <: AbstractEntropyEstimator} + + # Gradient-stopping for computing the sticking-the-landing 
control variate + q_η_stop = if EntropyEst isa MonteCarloEntropy{true} + rebuild(λ) + else + nothing + end grad!(ADBackend(), λ, out) do λ′ q_η = rebuild(λ′) - ηs = rand(rng, q_η, n_samples) - - 𝔼ℓ = objective.energy_estimator(q_η, ηs) - ℍ = objective.entropy_estimator(q_η, ηs) - elbo = 𝔼ℓ + ℍ - -elbo + q_η_entropy = if EntropyEst isa MonteCarloEntropy{true} + q_η_stop + else + q_η + end + -elbo(q_η; rng, n_samples=elbo.n_samples, q_η_entropy) end nelbo = DiffResults.value(out) (elbo=-nelbo,) diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl index d7fb7054..8efb7c71 100644 --- a/src/objectives/elbo/entropy.jl +++ b/src/objectives/elbo/entropy.jl @@ -1,13 +1,38 @@ -struct ClosedFormEntropy <: AbstractEntropyEstimator -end +struct ClosedFormEntropy <: AbstractEntropyEstimator end -function (::ClosedFormEntropy)(q, ηs::AbstractMatrix) +function (::ClosedFormEntropy)(q, ::AbstractMatrix) entropy(q) end -struct MonteCarloEntropy <: AbstractEntropyEstimator -end +struct MonteCarloEntropy{IsStickingTheLanding} <: AbstractEntropyEstimator end + +MonteCarloEntropy() = MonteCarloEntropy{false}() + +""" + Sticking the Landing Control Variate + + # Explanation + + This estimator forms a control variate of the form + + c(z) = 𝔼-logq(z) + logq(z) = ℍ[q] - logq(z) + + Adding this to the closed-form entropy ELBO estimator yields: + + ELBO - c(z) = 𝔼logπ(z) + ℍ[q] - c(z) = 𝔼logπ(z) - logq(z), + + which has the same expectation, but lower variance when π ≈ q, + and higher variance when π ≉ q. + + # Reference + + Roeder, Geoffrey, Yuhuai Wu, and David K. Duvenaud. + "Sticking the landing: Simple, lower-variance gradient estimators for + variational inference." + Advances in Neural Information Processing Systems 30 (2017). 
+""" +StickingTheLandingEntropy() = MonteCarloEntropy{true}() function (::MonteCarloEntropy)(q, ηs::AbstractMatrix) n_samples = size(ηs, 2) From 768641b1979f4e63125780e53f48e21794bbcdd2 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sat, 10 Jun 2023 22:41:50 +0100 Subject: [PATCH 020/144] migrate to Optimisers.jl --- Project.toml | 1 + src/AdvancedVI.jl | 11 +++-------- src/vi.jl | 27 ++++++++++++++++----------- 3 files changed, 20 insertions(+), 19 deletions(-) diff --git a/Project.toml b/Project.toml index 38a5026a..ba807698 100644 --- a/Project.toml +++ b/Project.toml @@ -12,6 +12,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 9b9d3ab2..5a02501b 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -4,6 +4,8 @@ using Random: Random using Functors +using Optimisers + using DocStringExtensions using ProgressMeter @@ -12,8 +14,7 @@ using LinearAlgebra: AbstractTriangular using LogDensityProblems -using ForwardDiff -using Tracker +using ForwardDiff, Tracker using FillArrays using PDMats @@ -24,12 +25,6 @@ using Bijectors using StatsBase using StatsBase: entropy -const PROGRESS = Ref(true) -function turnprogress(switch::Bool) - @info("[AdvancedVI]: global PROGRESS is set as $switch") - PROGRESS[] = switch -end - const DEBUG = Bool(parse(Int, get(ENV, "DEBUG_ADVANCEDVI", "0"))) include("ad.jl") diff --git a/src/vi.jl b/src/vi.jl index 4bf4595f..6c8b26d1 100644 --- a/src/vi.jl +++ b/src/vi.jl @@ -10,12 +10,13 @@ Iteratively updates parameters by calling `grad!` and using the given `optimizer the steps. 
""" function optimize( - objective::AbstractVariationalObjective, - rebuild::Function, + objective ::AbstractVariationalObjective, + rebuild, n_max_iter::Int, - λ::AbstractVector{<:Real}; - optimizer = TruncatedADAGrad(), - rng = Random.GLOBAL_RNG + λ ::AbstractVector{<:Real}; + optimizer ::Optimisers.AbstractRule = TruncatedADAGrad(), + rng ::Random.AbstractRNG = Random.GLOBAL_RNG, + progress ::Bool = true ) # TODO: really need a better way to warn the user about potentially # not using the correct accumulator @@ -24,21 +25,25 @@ function optimize( @info "[$(string(objective))] Should only be seen once: optimizer created for θ" objectid(λ) end + optstate = Optimisers.init(optimizer, λ) grad_buf = DiffResults.GradientResult(λ) i = 0 prog = ProgressMeter.Progress( - n_max_iter; desc="[$(string(objective))] Optimizing...", barlen=0, enabled=PROGRESS[]) + n_max_iter; + desc = "[$(string(objective))] Optimizing...", + barlen = 0, + enabled = progress, + showspeed = true) # add criterion? A running mean maybe? time_elapsed = @elapsed begin for i = 1:n_max_iter stats = estimate_gradient!(rng, objective, λ, rebuild, grad_buf) - - # apply update rule - Δλ = DiffResults.gradient(grad_buf) - Δλ = apply!(optimizer, λ, Δλ) - @. λ = λ - Δλ + g = DiffResults.gradient(grad_buf) + + optstate, Δλ = Optimisers.apply!(optimizer, optstate, λ, g) + Optimisers.subtract!(λ, Δλ) stat′ = (Δλ=norm(Δλ),) stats = merge(stats, stat′) From ca02fa315486a0977327f3e2824cd87b40b1908a Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sat, 10 Jun 2023 22:42:38 +0100 Subject: [PATCH 021/144] remove execution time measurement (replace later with somethin else) --- src/vi.jl | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/src/vi.jl b/src/vi.jl index 6c8b26d1..e5062def 100644 --- a/src/vi.jl +++ b/src/vi.jl @@ -36,23 +36,20 @@ function optimize( enabled = progress, showspeed = true) - # add criterion? A running mean maybe? 
- time_elapsed = @elapsed begin - for i = 1:n_max_iter - stats = estimate_gradient!(rng, objective, λ, rebuild, grad_buf) - g = DiffResults.gradient(grad_buf) + for i = 1:n_max_iter + stats = estimate_gradient!(rng, objective, λ, rebuild, grad_buf) + g = DiffResults.gradient(grad_buf) - optstate, Δλ = Optimisers.apply!(optimizer, optstate, λ, g) - Optimisers.subtract!(λ, Δλ) + optstate, Δλ = Optimisers.apply!(optimizer, optstate, λ, g) + Optimisers.subtract!(λ, Δλ) - stat′ = (Δλ=norm(Δλ),) - stats = merge(stats, stat′) + stat′ = (Δλ=norm(Δλ),) + stats = merge(stats, stat′) - AdvancedVI.DEBUG && @debug "Step $i" stats... + AdvancedVI.DEBUG && @debug "Step $i" stats... pm_next!(prog, stats) - end end - return λ + λ end # function vi(grad_estimator, q, θ_init; optimizer = TruncatedADAGrad(), rng = Random.GLOBAL_RNG) From a48377f016c82461000ba10c35803a5181f4b4a9 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Mon, 12 Jun 2023 21:47:22 +0100 Subject: [PATCH 022/144] fix use multiple dispatch for deciding whether to stop entropy grad. 
--- src/objectives/elbo/elbo.jl | 21 +++++++-------------- src/objectives/elbo/entropy.jl | 4 ++++ src/vi.jl | 4 ++-- 3 files changed, 13 insertions(+), 16 deletions(-) diff --git a/src/objectives/elbo/elbo.jl b/src/objectives/elbo/elbo.jl index 343581d8..cebd7d82 100644 --- a/src/objectives/elbo/elbo.jl +++ b/src/objectives/elbo/elbo.jl @@ -15,6 +15,8 @@ struct ELBO{EnergyEst <: AbstractEnergyEstimator, n_samples::Int end +skip_entropy_gradient(elbo::ELBO) = skip_entropy_gradient(elbo.entropy_estimator) + Base.string(::ELBO) = "ELBO" function ADVI(ℓπ, b⁻¹, n_samples::Int) @@ -33,28 +35,19 @@ end function estimate_gradient!( rng::Random.AbstractRNG, - elbo::ELBO{EnergyEst, EntropyEst}, + elbo::ELBO, λ::Vector{<:Real}, rebuild, - out::DiffResults.MutableDiffResult) where {EnergyEst <: AbstractEnergyEstimator, - EntropyEst <: AbstractEntropyEstimator} + out::DiffResults.MutableDiffResult) # Gradient-stopping for computing the sticking-the-landing control variate - q_η_stop = if EntropyEst isa MonteCarloEntropy{true} - rebuild(λ) - else - nothing - end + q_η_stop = skip_entropy_gradient(elbo) ? rebuild(λ) : nothing grad!(ADBackend(), λ, out) do λ′ q_η = rebuild(λ′) - q_η_entropy = if EntropyEst isa MonteCarloEntropy{true} - q_η_stop - else - q_η - end + q_η_entropy = skip_entropy_gradient(elbo) ? 
q_η_stop : q_η -elbo(q_η; rng, n_samples=elbo.n_samples, q_η_entropy) end nelbo = DiffResults.value(out) - (elbo=-nelbo,) + out, (elbo=-nelbo,) end diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl index 8efb7c71..50f498d6 100644 --- a/src/objectives/elbo/entropy.jl +++ b/src/objectives/elbo/entropy.jl @@ -5,6 +5,8 @@ function (::ClosedFormEntropy)(q, ::AbstractMatrix) entropy(q) end +skip_entropy_gradient(::ClosedFormEntropy) = false + struct MonteCarloEntropy{IsStickingTheLanding} <: AbstractEntropyEstimator end MonteCarloEntropy() = MonteCarloEntropy{false}() @@ -34,6 +36,8 @@ MonteCarloEntropy() = MonteCarloEntropy{false}() """ StickingTheLandingEntropy() = MonteCarloEntropy{true}() +skip_entropy_gradient(::MonteCarloEntropy{IsStickingTheLanding}) where {IsStickingTheLanding} = IsStickingTheLanding + function (::MonteCarloEntropy)(q, ηs::AbstractMatrix) n_samples = size(ηs, 2) mapreduce(+, eachcol(ηs)) do ηᵢ diff --git a/src/vi.jl b/src/vi.jl index e5062def..8b8fe14f 100644 --- a/src/vi.jl +++ b/src/vi.jl @@ -37,8 +37,8 @@ function optimize( showspeed = true) for i = 1:n_max_iter - stats = estimate_gradient!(rng, objective, λ, rebuild, grad_buf) - g = DiffResults.gradient(grad_buf) + grad_buf, stats = estimate_gradient!(rng, objective, λ, rebuild, grad_buf) + g = DiffResults.gradient(grad_buf) optstate, Δλ = Optimisers.apply!(optimizer, optstate, λ, g) Optimisers.subtract!(λ, Δλ) From 0b40ccf6ef10e6ebef9d6372e407731bb4dc2ca0 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Mon, 12 Jun 2023 22:16:30 +0100 Subject: [PATCH 023/144] add termination decision, callback arguments --- src/vi.jl | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/src/vi.jl b/src/vi.jl index 8b8fe14f..1a4d57ec 100644 --- a/src/vi.jl +++ b/src/vi.jl @@ -16,7 +16,9 @@ function optimize( λ ::AbstractVector{<:Real}; optimizer ::Optimisers.AbstractRule = TruncatedADAGrad(), rng ::Random.AbstractRNG = Random.GLOBAL_RNG, - progress ::Bool 
= true + progress ::Bool = true, + callback! = nothing, + terminate = (args...) -> false, ) # TODO: really need a better way to warn the user about potentially # not using the correct accumulator @@ -28,6 +30,7 @@ function optimize( optstate = Optimisers.init(optimizer, λ) grad_buf = DiffResults.GradientResult(λ) + q = rebuild(λ) i = 0 prog = ProgressMeter.Progress( n_max_iter; @@ -43,11 +46,22 @@ function optimize( optstate, Δλ = Optimisers.apply!(optimizer, optstate, λ, g) Optimisers.subtract!(λ, Δλ) - stat′ = (Δλ=norm(Δλ),) + stat′ = (Δλ=norm(Δλ), gradient_norm=norm(g)) stats = merge(stats, stat′) + q = rebuild(λ) + + if !isnothing(callback!) + stat′ = callback!(q, stats) + stats = !isnothing(stat′) ? merge(stat′, stats) : stats + end AdvancedVI.DEBUG && @debug "Step $i" stats... pm_next!(prog, stats) + + # Termination decision is work in progress + if terminate(rng, q, objective, stats) + break + end end λ end From 21db3fb842d226148ee23b758c0756e332132066 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Mon, 12 Jun 2023 22:35:48 +0100 Subject: [PATCH 024/144] add Base.show to modules --- src/objectives/elbo/advi_energy.jl | 2 ++ src/objectives/elbo/elbo.jl | 6 +++++- src/objectives/elbo/entropy.jl | 4 ++++ src/vi.jl | 1 - 4 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/objectives/elbo/advi_energy.jl b/src/objectives/elbo/advi_energy.jl index b27b752e..078a157e 100644 --- a/src/objectives/elbo/advi_energy.jl +++ b/src/objectives/elbo/advi_energy.jl @@ -26,6 +26,8 @@ end ADVIEnergy(prob) = ADVIEnergy(prob, identity) +Base.show(io::IO, energy::ADVIEnergy) = print(io, "ADVIEnergy()") + function (energy::ADVIEnergy)(q, ηs::AbstractMatrix) n_samples = size(ηs, 2) mapreduce(+, eachcol(ηs)) do ηᵢ diff --git a/src/objectives/elbo/elbo.jl b/src/objectives/elbo/elbo.jl index cebd7d82..b26516d9 100644 --- a/src/objectives/elbo/elbo.jl +++ b/src/objectives/elbo/elbo.jl @@ -17,7 +17,11 @@ end skip_entropy_gradient(elbo::ELBO) = 
skip_entropy_gradient(elbo.entropy_estimator) -Base.string(::ELBO) = "ELBO" +Base.show(io::IO, elbo::ELBO) = print( + io, + "ELBO(energy_estimator=$(elbo.energy_estimator), " * + "entropy_estimator=$(elbo.entropy_estimator)), " * + "n_samples=$(elbo.n_samples))") function ADVI(ℓπ, b⁻¹, n_samples::Int) ELBO(ADVIEnergy(ℓπ, b⁻¹), ClosedFormEntropy(), n_samples) diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl index 50f498d6..ddeb64a9 100644 --- a/src/objectives/elbo/entropy.jl +++ b/src/objectives/elbo/entropy.jl @@ -11,6 +11,8 @@ struct MonteCarloEntropy{IsStickingTheLanding} <: AbstractEntropyEstimator end MonteCarloEntropy() = MonteCarloEntropy{false}() +Base.show(io::IO, entropy::MonteCarloEntropy{false}) = print(io, "MonteCarloEntropy()") + """ Sticking the Landing Control Variate @@ -38,6 +40,8 @@ StickingTheLandingEntropy() = MonteCarloEntropy{true}() skip_entropy_gradient(::MonteCarloEntropy{IsStickingTheLanding}) where {IsStickingTheLanding} = IsStickingTheLanding +Base.show(io::IO, entropy::MonteCarloEntropy{true}) = print(io, "StickingTheLandingEntropy()") + function (::MonteCarloEntropy)(q, ηs::AbstractMatrix) n_samples = size(ηs, 2) mapreduce(+, eachcol(ηs)) do ηᵢ diff --git a/src/vi.jl b/src/vi.jl index 1a4d57ec..605464b6 100644 --- a/src/vi.jl +++ b/src/vi.jl @@ -34,7 +34,6 @@ function optimize( i = 0 prog = ProgressMeter.Progress( n_max_iter; - desc = "[$(string(objective))] Optimizing...", barlen = 0, enabled = progress, showspeed = true) From 25c51b4796b2e550d1ee9747e5ccbf81a48aff38 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Mon, 12 Jun 2023 23:03:25 +0100 Subject: [PATCH 025/144] add interface calling `restructure`, rename rebuild -> restructure --- src/objectives/elbo/elbo.jl | 6 ++-- src/vi.jl | 61 ++++++++++++++++++------------------- 2 files changed, 32 insertions(+), 35 deletions(-) diff --git a/src/objectives/elbo/elbo.jl b/src/objectives/elbo/elbo.jl index b26516d9..b3bad3c0 100644 --- 
a/src/objectives/elbo/elbo.jl +++ b/src/objectives/elbo/elbo.jl @@ -41,14 +41,14 @@ function estimate_gradient!( rng::Random.AbstractRNG, elbo::ELBO, λ::Vector{<:Real}, - rebuild, + restructure, out::DiffResults.MutableDiffResult) # Gradient-stopping for computing the sticking-the-landing control variate - q_η_stop = skip_entropy_gradient(elbo) ? rebuild(λ) : nothing + q_η_stop = skip_entropy_gradient(elbo) ? restructure(λ) : nothing grad!(ADBackend(), λ, out) do λ′ - q_η = rebuild(λ′) + q_η = restructure(λ′) q_η_entropy = skip_entropy_gradient(elbo) ? q_η_stop : q_η -elbo(q_η; rng, n_samples=elbo.n_samples, q_η_entropy) end diff --git a/src/vi.jl b/src/vi.jl index 605464b6..f1f4bc25 100644 --- a/src/vi.jl +++ b/src/vi.jl @@ -11,9 +11,9 @@ the steps. """ function optimize( objective ::AbstractVariationalObjective, - rebuild, - n_max_iter::Int, - λ ::AbstractVector{<:Real}; + restructure, + λ ::AbstractVector{<:Real}, + n_max_iter::Int; optimizer ::Optimisers.AbstractRule = TruncatedADAGrad(), rng ::Random.AbstractRNG = Random.GLOBAL_RNG, progress ::Bool = true, @@ -30,50 +30,47 @@ function optimize( optstate = Optimisers.init(optimizer, λ) grad_buf = DiffResults.GradientResult(λ) - q = rebuild(λ) - i = 0 - prog = ProgressMeter.Progress( - n_max_iter; - barlen = 0, - enabled = progress, - showspeed = true) + prog = ProgressMeter.Progress(n_max_iter; + barlen = 0, + enabled = progress, + showspeed = true) + stats = Vector{NamedTuple}(undef, n_max_iter) - for i = 1:n_max_iter - grad_buf, stats = estimate_gradient!(rng, objective, λ, rebuild, grad_buf) + for t = 1:n_max_iter + grad_buf, stat = estimate_gradient!(rng, objective, λ, restructure, grad_buf) g = DiffResults.gradient(grad_buf) optstate, Δλ = Optimisers.apply!(optimizer, optstate, λ, g) Optimisers.subtract!(λ, Δλ) stat′ = (Δλ=norm(Δλ), gradient_norm=norm(g)) - stats = merge(stats, stat′) - q = rebuild(λ) + stat = merge(stat, stat′) + q = restructure(λ) if !isnothing(callback!) 
- stat′ = callback!(q, stats) - stats = !isnothing(stat′) ? merge(stat′, stats) : stats + stat′ = callback!(q, stat) + stat = !isnothing(stat′) ? merge(stat′, stat) : stat end - AdvancedVI.DEBUG && @debug "Step $i" stats... - pm_next!(prog, stats) + AdvancedVI.DEBUG && @debug "Step $i" stat... + + pm_next!(prog, stat) + stats[t] = stat # Termination decision is work in progress - if terminate(rng, q, objective, stats) + if terminate(rng, q, objective, stat) + stats = stats[1:t] break end end - λ + λ, stats end -# function vi(grad_estimator, q, θ_init; optimizer = TruncatedADAGrad(), rng = Random.GLOBAL_RNG) -# θ = copy(θ_init) -# optimize!(grad_estimator, rebuild, n_max_iter, λ, optimizer = optimizer, rng = rng) - -# # If `q` is a mean-field approx we use the specialized `update` function -# if q isa Distribution -# return update(q, θ) -# else -# # Otherwise we assume it's a mapping θ → q -# return q(θ) -# end -# end +function optimize(objective::AbstractVariationalObjective, + q, + n_max_iter::Int; + kwargs...) + λ, restructure = Optimisers.destructure(q) + λ, stats = optimize(objective, restructure, λ, n_max_iter; kwargs...) 
+ restructure(λ), stats +end From fc200462e0a6929ca580d6cabaad27afd179b30f Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 13 Jun 2023 00:33:47 +0100 Subject: [PATCH 026/144] add estimator state interface, add control variate interface to ADVI --- src/AdvancedVI.jl | 12 ++++++- src/objectives/elbo/advi.jl | 64 +++++++++++++++++++++++++++++++++++++ src/objectives/elbo/elbo.jl | 57 --------------------------------- src/vi.jl | 22 +++++++------ 4 files changed, 88 insertions(+), 67 deletions(-) create mode 100644 src/objectives/elbo/advi.jl delete mode 100644 src/objectives/elbo/elbo.jl diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 5a02501b..f2eb2317 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -166,7 +166,17 @@ end # estimators abstract type AbstractVariationalObjective end -include("objectives/elbo/elbo.jl") +function estimate_gradient end + +abstract type AbstractEnergyEstimator end +abstract type AbstractEntropyEstimator end +abstract type AbstractControlVariate end + +init(::Nothing) = nothing + +update(::Nothing, ::Nothing) = (nothing, nothing) + +include("objectives/elbo/advi.jl") include("objectives/elbo/advi_energy.jl") include("objectives/elbo/entropy.jl") diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl new file mode 100644 index 00000000..66e5f320 --- /dev/null +++ b/src/objectives/elbo/advi.jl @@ -0,0 +1,64 @@ + +struct ADVI{EnergyEst <: AbstractEnergyEstimator, + EntropyEst <: AbstractEntropyEstimator, + ControlVar <: Union{<: AbstractControlVariate, Nothing}} <: AbstractVariationalObjective + energy_estimator::EnergyEst + entropy_estimator::EntropyEst + control_variate::ControlVar + n_samples::Int +end + +skip_entropy_gradient(advi::ADVI) = skip_entropy_gradient(advi.entropy_estimator) + +init(advi::ADVI) = init(advi.control_variate) + +Base.show(io::IO, advi::ADVI) = print( + io, + "ADVI(energy_estimator=$(advi.energy_estimator), " * + "entropy_estimator=$(advi.entropy_estimator)), " * + 
"n_samples=$(advi.n_samples))") + +function ADVI(energy_estimator::AbstractEnergyEstimator, + entropy_estimator::AbstractEntropyEstimator, + n_samples::Int) + ADVI(energy_estimator, entropy_estimator, nothing, n_samples) +end + +function ADVI(ℓπ, b⁻¹, n_samples::Int) + ADVI(ADVIEnergy(ℓπ, b⁻¹), ClosedFormEntropy(), n_samples) +end + +function (advi::ADVI)(q_η::ContinuousMultivariateDistribution; + rng ::Random.AbstractRNG = Random.default_rng(), + n_samples ::Int = advi.n_samples, + ηs ::AbstractMatrix = rand(rng, q_η, n_samples), + q_η_entropy::ContinuousMultivariateDistribution = q_η) + 𝔼ℓ = advi.energy_estimator(q_η, ηs) + ℍ = advi.entropy_estimator(q_η_entropy, ηs) + 𝔼ℓ + ℍ +end + +function estimate_gradient( + rng::Random.AbstractRNG, + advi::ADVI, + est_state, + λ::Vector{<:Real}, + restructure, + out::DiffResults.MutableDiffResult) + + # Gradient-stopping for computing the sticking-the-landing control variate + q_η_stop = skip_entropy_gradient(advi.entropy_estimator) ? restructure(λ) : nothing + + grad!(ADBackend(), λ, out) do λ′ + q_η = restructure(λ′) + q_η_entropy = skip_entropy_gradient(advi.entropy_estimator) ? q_η_stop : q_η + -advi(q_η; rng, q_η_entropy) + end + nelbo = DiffResults.value(out) + stat = (elbo=-nelbo,) + + est_state, stat′ = update(advi.control_variate, est_state) + stat = !isnothing(stat′) ? merge(stat′, stat) : stat + + out, est_state, stat +end diff --git a/src/objectives/elbo/elbo.jl b/src/objectives/elbo/elbo.jl deleted file mode 100644 index b3bad3c0..00000000 --- a/src/objectives/elbo/elbo.jl +++ /dev/null @@ -1,57 +0,0 @@ - -abstract type AbstractEnergyEstimator end -abstract type AbstractEntropyEstimator end - -struct ELBO{EnergyEst <: AbstractEnergyEstimator, - EntropyEst <: AbstractEntropyEstimator} <: AbstractVariationalObjective - # Evidence Lower Bound - # - # Jordan, Michael I., et al. - # "An introduction to variational methods for graphical models." - # Machine learning 37 (1999): 183-233. 
- - energy_estimator::EnergyEst - entropy_estimator::EntropyEst - n_samples::Int -end - -skip_entropy_gradient(elbo::ELBO) = skip_entropy_gradient(elbo.entropy_estimator) - -Base.show(io::IO, elbo::ELBO) = print( - io, - "ELBO(energy_estimator=$(elbo.energy_estimator), " * - "entropy_estimator=$(elbo.entropy_estimator)), " * - "n_samples=$(elbo.n_samples))") - -function ADVI(ℓπ, b⁻¹, n_samples::Int) - ELBO(ADVIEnergy(ℓπ, b⁻¹), ClosedFormEntropy(), n_samples) -end - -function (elbo::ELBO)(q_η::ContinuousMultivariateDistribution; - rng = Random.default_rng(), - n_samples::Int = elbo.n_samples, - q_η_entropy::ContinuousMultivariateDistribution = q_η) - ηs = rand(rng, q_η, n_samples) - 𝔼ℓ = elbo.energy_estimator(q_η, ηs) - ℍ = elbo.entropy_estimator(q_η_entropy, ηs) - 𝔼ℓ + ℍ -end - -function estimate_gradient!( - rng::Random.AbstractRNG, - elbo::ELBO, - λ::Vector{<:Real}, - restructure, - out::DiffResults.MutableDiffResult) - - # Gradient-stopping for computing the sticking-the-landing control variate - q_η_stop = skip_entropy_gradient(elbo) ? restructure(λ) : nothing - - grad!(ADBackend(), λ, out) do λ′ - q_η = restructure(λ′) - q_η_entropy = skip_entropy_gradient(elbo) ? 
q_η_stop : q_η - -elbo(q_η; rng, n_samples=elbo.n_samples, q_η_entropy) - end - nelbo = DiffResults.value(out) - out, (elbo=-nelbo,) -end diff --git a/src/vi.jl b/src/vi.jl index f1f4bc25..ebb246be 100644 --- a/src/vi.jl +++ b/src/vi.jl @@ -27,8 +27,9 @@ function optimize( @info "[$(string(objective))] Should only be seen once: optimizer created for θ" objectid(λ) end - optstate = Optimisers.init(optimizer, λ) - grad_buf = DiffResults.GradientResult(λ) + opt_state = Optimisers.init(optimizer, λ) + est_state = init(objective) + grad_buf = DiffResults.GradientResult(λ) prog = ProgressMeter.Progress(n_max_iter; barlen = 0, @@ -37,22 +38,25 @@ function optimize( stats = Vector{NamedTuple}(undef, n_max_iter) for t = 1:n_max_iter - grad_buf, stat = estimate_gradient!(rng, objective, λ, restructure, grad_buf) - g = DiffResults.gradient(grad_buf) + stat = (iteration=t,) - optstate, Δλ = Optimisers.apply!(optimizer, optstate, λ, g) + grad_buf, est_state, stat′ = estimate_gradient(rng, objective, est_state, λ, restructure, grad_buf) + g = DiffResults.gradient(grad_buf) + stat = merge(stat, stat′) + + opt_state, Δλ = Optimisers.apply!(optimizer, opt_state, λ, g) Optimisers.subtract!(λ, Δλ) + stat′ = (iteration=t, Δλ=norm(Δλ), gradient_norm=norm(g)) + stat = merge(stat, stat′) - stat′ = (Δλ=norm(Δλ), gradient_norm=norm(g)) - stat = merge(stat, stat′) - q = restructure(λ) + q = restructure(λ) if !isnothing(callback!) stat′ = callback!(q, stat) stat = !isnothing(stat′) ? merge(stat′, stat) : stat end - AdvancedVI.DEBUG && @debug "Step $i" stat... + AdvancedVI.DEBUG && @debug "Step $t" stat... 
pm_next!(prog, stat) stats[t] = stat From 6faa807f067ff77856c307ef4baa11865616deae Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 13 Jun 2023 00:39:05 +0100 Subject: [PATCH 027/144] fix `show(advi)` to show control variate --- src/objectives/elbo/advi.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl index 66e5f320..de2c683b 100644 --- a/src/objectives/elbo/advi.jl +++ b/src/objectives/elbo/advi.jl @@ -15,7 +15,8 @@ init(advi::ADVI) = init(advi.control_variate) Base.show(io::IO, advi::ADVI) = print( io, "ADVI(energy_estimator=$(advi.energy_estimator), " * - "entropy_estimator=$(advi.entropy_estimator)), " * + "entropy_estimator=$(advi.entropy_estimator), " * + (!isnothing(advi.control_variate) ? "control_variate=$(advi.control_variate), " : "") * "n_samples=$(advi.n_samples))") function ADVI(energy_estimator::AbstractEnergyEstimator, From 7095d276f5947b855289099a0ce56f2106c8b16c Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 13 Jun 2023 00:39:45 +0100 Subject: [PATCH 028/144] fix simplify `show(advi.control_variate)` --- src/objectives/elbo/advi.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl index de2c683b..dc2962ee 100644 --- a/src/objectives/elbo/advi.jl +++ b/src/objectives/elbo/advi.jl @@ -16,7 +16,7 @@ Base.show(io::IO, advi::ADVI) = print( io, "ADVI(energy_estimator=$(advi.energy_estimator), " * "entropy_estimator=$(advi.entropy_estimator), " * - (!isnothing(advi.control_variate) ? 
"control_variate=$(advi.control_variate), " : "") * + "control_variate=$(advi.control_variate), " * "n_samples=$(advi.n_samples))") function ADVI(energy_estimator::AbstractEnergyEstimator, From 9169ae262f8ac289d8e7355f8642584e18da3614 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 13 Jun 2023 00:51:48 +0100 Subject: [PATCH 029/144] fix type piracy by wrapping location-scale bijected distribution --- src/distributions/location_scale.jl | 67 ++++++++++++++++------------- 1 file changed, 38 insertions(+), 29 deletions(-) diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl index dd9b5f2a..f3c95d0c 100644 --- a/src/distributions/location_scale.jl +++ b/src/distributions/location_scale.jl @@ -1,41 +1,51 @@ -function LocationScale(μ::AbstractVector, - L::Union{<: AbstractTriangular, - <: Diagonal}, - q₀::ContinuousMultivariateDistribution) - @assert (length(μ) == size(L,1)) && (length(μ) == size(L,2)) - transformed(q₀, Bijectors.Shift(μ) ∘ Bijectors.Scale(L)) +import Base: rand, _rand! 
+ +struct LocationScale{ReparamMvDist <: Bijectors.TransformedDistribution} <: ContinuousMultivariateDistribution + q_trans::ReparamMvDist + + function LocationScale(μ::AbstractVector, + L::Union{<: AbstractTriangular, + <: Diagonal}, + q₀::ContinuousMultivariateDistribution) + @assert (length(μ) == size(L,1)) && (length(μ) == size(L,2)) + q_trans = transformed(q₀, Bijectors.Shift(μ) ∘ Bijectors.Scale(L)) + new{typeof(q_trans)}(q_trans) + end + + function LocationScale(q_trans::Bijectors.TransformedDistribution) + new{typeof(q_trans)}(q_trans) + end end -function StatsBase.entropy( - q_trans::MultivariateTransformed{<: ContinuousMultivariateDistribution, - <: Bijectors.ComposedFunction{ - <: Bijectors.Shift, - <: Bijectors.Scale}}) - q_base = q_trans.dist - scale = q_trans.transform.inner.a +Functors.@functor LocationScale + +Base.length(q::LocationScale) = length(q.q_trans) +Base.size(q::LocationScale) = size(q.q_trans) + +function StatsBase.entropy(q::LocationScale) + q_base = q.q_trans.dist + scale = q.q_trans.transform.inner.a entropy(q_base) + first(logabsdet(scale)) end -function Distributions.logpdf( - q_trans::MultivariateTransformed{<: ContinuousMultivariateDistribution, - <: Bijectors.ComposedFunction{ - <: Bijectors.Shift, - <: Bijectors.Scale}}, - z::AbstractVector) - q_base = q_trans.dist - reparam = q_trans.transform - scale = q_trans.transform.inner.a - η = inverse(reparam)(z) - logpdf(q_base, η) - first(logabsdet(scale)) -end + +Distributions.logpdf(q::LocationScale, z::AbstractVector) = logpdf(q.q_trans, z) + +_logpdf(q::LocationScale, y::AbstractVector) = _logpdf(q.q_trans, y) + +rand(q::LocationScale) = rand(q.q_trans) + +rand(rng::Random.AbstractRNG, q::LocationScale, num_samples::Int) = rand(rng, q.q_trans, num_samples) + +_rand!(rng::Random.AbstractRNG, q::LocationScale, x::AbstractVector{<:Real}) = _rand!(rng, q.q_trans, x) + function FullRankGaussian(μ::AbstractVector{T}, L::AbstractTriangular{T,S}) where {T <: Real, S} @assert (length(μ) == 
size(L,1)) && (length(μ) == size(L,2)) n_dims = length(μ) - q_base = MvNormal(FillArrays.Zeros{T}(n_dims), - PDMats.ScalMat{T}(n_dims, one(T))) + q_base = MvNormal(FillArrays.Zeros{T}(n_dims), PDMats.ScalMat{T}(n_dims, one(T))) LocationScale(μ, L, q_base) end @@ -43,7 +53,6 @@ function MeanFieldGaussian(μ::AbstractVector{T}, L::Diagonal{T,V}) where {T <: Real, V} @assert (length(μ) == size(L,1)) n_dims = length(μ) - q_base = MvNormal(FillArrays.Zeros{T}(n_dims), - PDMats.ScalMat{T}(n_dims, one(T))) + q_base = MvNormal(FillArrays.Zeros{T}(n_dims), PDMats.ScalMat{T}(n_dims, one(T))) LocationScale(μ, L, q_base) end From 3db73011a430fb3aa5830264be687d860410f483 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Mon, 26 Jun 2023 23:01:27 +0100 Subject: [PATCH 030/144] remove old AdvancedVI custom optimizers --- Project.toml | 1 + src/AdvancedVI.jl | 15 +++----- src/optimisers.jl | 94 ----------------------------------------------- src/vi.jl | 11 +----- 4 files changed, 8 insertions(+), 113 deletions(-) delete mode 100644 src/optimisers.jl diff --git a/Project.toml b/Project.toml index ba807698..d2708915 100644 --- a/Project.toml +++ b/Project.toml @@ -20,6 +20,7 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" +UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" [compat] Bijectors = "0.11, 0.12" diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index f2eb2317..76c6d859 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -1,9 +1,12 @@ + module AdvancedVI -using Random: Random +using UnPack -using Functors +import Random: AbstractRNG, default_rng +import Distributions: logpdf, _logpdf, rand, _rand!, _rand! 
+using Functors using Optimisers using DocStringExtensions @@ -31,11 +34,6 @@ include("ad.jl") using Requires function __init__() - @require Flux="587475ba-b771-5e3f-ad9e-33799f191a9c" begin - apply!(o, x, Δ) = Flux.Optimise.apply!(o, x, Δ) - Flux.Optimise.apply!(o::TruncatedADAGrad, x, Δ) = apply!(o, x, Δ) - Flux.Optimise.apply!(o::DecayedADAGrad, x, Δ) = apply!(o, x, Δ) - end @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin include("compat/zygote.jl") export ZygoteAD @@ -183,9 +181,6 @@ include("objectives/elbo/entropy.jl") # Variational Families include("distributions/location_scale.jl") -# optimisers -include("optimisers.jl") - include("utils.jl") include("vi.jl") diff --git a/src/optimisers.jl b/src/optimisers.jl deleted file mode 100644 index 8077f98c..00000000 --- a/src/optimisers.jl +++ /dev/null @@ -1,94 +0,0 @@ -const ϵ = 1e-8 - -""" - TruncatedADAGrad(η=0.1, τ=1.0, n=100) - -Implements a truncated version of AdaGrad in the sense that only the `n` previous gradient norms are used to compute the scaling rather than *all* previous. It has parameter specific learning rates based on how frequently it is updated. - -## Parameters - - η: learning rate - - τ: constant scale factor - - n: number of previous gradient norms to use in the scaling. -``` -## References -[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser. -Parameters don't need tuning. - -[TruncatedADAGrad](https://arxiv.org/abs/1506.03431v2) (Appendix E). 
-""" -mutable struct TruncatedADAGrad - eta::Float64 - tau::Float64 - n::Int - - iters::IdDict - acc::IdDict -end - -function TruncatedADAGrad(η = 0.1, τ = 1.0, n = 100) - TruncatedADAGrad(η, τ, n, IdDict(), IdDict()) -end - -function apply!(o::TruncatedADAGrad, x, Δ) - T = eltype(Tracker.data(Δ)) - - η = o.eta - τ = o.tau - - g² = get!( - o.acc, - x, - [zeros(T, size(x)) for j = 1:o.n] - )::Array{typeof(Tracker.data(Δ)), 1} - i = get!(o.iters, x, 1)::Int - - # Example: suppose i = 12 and o.n = 10 - idx = mod(i - 1, o.n) + 1 # => idx = 2 - - # set the current - @inbounds @. g²[idx] = Δ^2 # => g²[2] = Δ^2 where Δ is the (o.n + 2)-th Δ - - # TODO: make more efficient and stable - s = sum(g²) - - # increment - o.iters[x] += 1 - - # TODO: increment (but "truncate") - # o.iters[x] = i > o.n ? o.n + mod(i, o.n) : i + 1 - - @. Δ *= η / (τ + sqrt(s) + ϵ) -end - -""" - DecayedADAGrad(η=0.1, pre=1.0, post=0.9) - -Implements a decayed version of AdaGrad. It has parameter specific learning rates based on how frequently it is updated. - -## Parameters - - η: learning rate - - pre: weight of new gradient norm - - post: weight of histroy of gradient norms -``` -## References -[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser. -Parameters don't need tuning. -""" -mutable struct DecayedADAGrad - eta::Float64 - pre::Float64 - post::Float64 - - acc::IdDict -end - -DecayedADAGrad(η = 0.1, pre = 1.0, post = 0.9) = DecayedADAGrad(η, pre, post, IdDict()) - -function apply!(o::DecayedADAGrad, x, Δ) - T = eltype(Tracker.data(Δ)) - - η = o.eta - acc = get!(o.acc, x, fill(T(ϵ), size(x)))::typeof(Tracker.data(x)) - @. acc = o.post * acc + o.pre * Δ^2 - @. 
Δ *= η / (√acc + ϵ) -end diff --git a/src/vi.jl b/src/vi.jl index ebb246be..842f187e 100644 --- a/src/vi.jl +++ b/src/vi.jl @@ -14,19 +14,12 @@ function optimize( restructure, λ ::AbstractVector{<:Real}, n_max_iter::Int; - optimizer ::Optimisers.AbstractRule = TruncatedADAGrad(), - rng ::Random.AbstractRNG = Random.GLOBAL_RNG, + optimizer ::Optimisers.AbstractRule = Optimisers.Adam(), + rng ::AbstractRNG = default_rng(), progress ::Bool = true, callback! = nothing, terminate = (args...) -> false, ) - # TODO: really need a better way to warn the user about potentially - # not using the correct accumulator - if (optimizer isa TruncatedADAGrad) && (λ ∉ keys(optimizer.acc)) - # this message should only occurr once in the optimization process - @info "[$(string(objective))] Should only be seen once: optimizer created for θ" objectid(λ) - end - opt_state = Optimisers.init(optimizer, λ) est_state = init(objective) grad_buf = DiffResults.GradientResult(λ) From e6a082aadbd3fa92e60fedf5373f2efbb1875ecc Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Mon, 26 Jun 2023 23:47:04 +0100 Subject: [PATCH 031/144] fix Location Scale to not depend on Bijectors --- src/distributions/location_scale.jl | 101 +++++++++++++++++----------- 1 file changed, 61 insertions(+), 40 deletions(-) diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl index f3c95d0c..c46b5111 100644 --- a/src/distributions/location_scale.jl +++ b/src/distributions/location_scale.jl @@ -1,58 +1,79 @@ -import Base: rand, _rand! 
- -struct LocationScale{ReparamMvDist <: Bijectors.TransformedDistribution} <: ContinuousMultivariateDistribution - q_trans::ReparamMvDist - - function LocationScale(μ::AbstractVector, - L::Union{<: AbstractTriangular, - <: Diagonal}, - q₀::ContinuousMultivariateDistribution) +struct VILocationScale{L, S, D, R} <: ContinuousMultivariateDistribution + location::L + scale ::S + dist ::D + epsilon ::R + + function VILocationScale(μ::AbstractVector{<:Real}, + L::Union{<:AbstractTriangular{<:Real}, + <:Diagonal{<:Real}}, + q_base::ContinuousUnivariateDistribution, + epsilon::Real) + # Restricting all the arguments to have the same types creates problems + # with dual-variable-based AD frameworks. @assert (length(μ) == size(L,1)) && (length(μ) == size(L,2)) - q_trans = transformed(q₀, Bijectors.Shift(μ) ∘ Bijectors.Scale(L)) - new{typeof(q_trans)}(q_trans) - end - - function LocationScale(q_trans::Bijectors.TransformedDistribution) - new{typeof(q_trans)}(q_trans) + new{typeof(μ), typeof(L), typeof(q_base), typeof(epsilon)}(μ, L, q_base, epsilon) end end -Functors.@functor LocationScale +Functors.@functor VILocationScale (location, scale) -Base.length(q::LocationScale) = length(q.q_trans) -Base.size(q::LocationScale) = size(q.q_trans) +Base.length(q::VILocationScale) = length(q.location) +Base.size(q::VILocationScale) = size(q.location) -function StatsBase.entropy(q::LocationScale) - q_base = q.q_trans.dist - scale = q.q_trans.transform.inner.a - entropy(q_base) + first(logabsdet(scale)) +function StatsBase.entropy(q::VILocationScale) + @unpack location, scale, dist = q + n_dims = length(location) + n_dims*entropy(dist) + first(logabsdet(scale)) end +function logpdf(q::VILocationScale, z::AbstractVector{<:Real}) + @unpack location, scale, dist = q + mapreduce(zᵢ -> logpdf(dist, zᵢ), +, scale \ (z - location)) + first(logabsdet(scale)) +end -Distributions.logpdf(q::LocationScale, z::AbstractVector) = logpdf(q.q_trans, z) - -_logpdf(q::LocationScale, y::AbstractVector) = 
_logpdf(q.q_trans, y) +function _logpdf(q::VILocationScale, z::AbstractVector{<:Real}) + @unpack location, scale, dist = q + mapreduce(zᵢ -> logpdf(dist, zᵢ), +, scale \ (z - location)) + first(logabsdet(scale)) +end -rand(q::LocationScale) = rand(q.q_trans) +function rand(q::VILocationScale) + @unpack location, scale, dist = q + n_dims = length(location) + scale*rand(dist, n_dims) + location +end -rand(rng::Random.AbstractRNG, q::LocationScale, num_samples::Int) = rand(rng, q.q_trans, num_samples) +function rand(rng::AbstractRNG, q::VILocationScale, num_samples::Int) + @unpack location, scale, dist = q + n_dims = length(location) + scale*rand(dist, n_dims, num_samples) .+ location +end -_rand!(rng::Random.AbstractRNG, q::LocationScale, x::AbstractVector{<:Real}) = _rand!(rng, q.q_trans, x) +function _rand!(rng::AbstractRNG, q::VILocationScale, x::AbstractVector{<:Real}) + @unpack location, scale, dist = q + rand!(rng, dist, x) + x .= scale*x + return x += location +end +function _rand!(rng::AbstractRNG, q::VILocationScale, x::AbstractMatrix{<:Real}) + @unpack location, scale, dist = q + rand!(rng, dist, x) + x *= scale + return x += location +end -function FullRankGaussian(μ::AbstractVector{T}, - L::AbstractTriangular{T,S}) where {T <: Real, S} - @assert (length(μ) == size(L,1)) && (length(μ) == size(L,2)) - n_dims = length(μ) - q_base = MvNormal(FillArrays.Zeros{T}(n_dims), PDMats.ScalMat{T}(n_dims, one(T))) - LocationScale(μ, L, q_base) +function VIFullRankGaussian(μ::AbstractVector{T}, + L::AbstractTriangular{T}, + epsilon::Real = eps(T)) where {T <: Real} + q_base = Normal{T}(zero(T), one(T)) + VILocationScale(μ, L, q_base, epsilon) end -function MeanFieldGaussian(μ::AbstractVector{T}, - L::Diagonal{T,V}) where {T <: Real, V} - @assert (length(μ) == size(L,1)) - n_dims = length(μ) - q_base = MvNormal(FillArrays.Zeros{T}(n_dims), PDMats.ScalMat{T}(n_dims, one(T))) - LocationScale(μ, L, q_base) +function VIMeanFieldGaussian(μ::AbstractVector{T}, + 
L::Diagonal{T}, + epsilon::Real = eps(T)) where {T <: Real} + q_base = Normal{T}(zero(T), one(T)) + VILocationScale(μ, L, q_base, epsilon) end From a034ebdec0e42d63211fe8e1c23d4b4e714a30bb Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 13 Jul 2023 00:50:33 +0100 Subject: [PATCH 032/144] fix RNG namespace --- src/objectives/elbo/advi.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl index dc2962ee..311a94f3 100644 --- a/src/objectives/elbo/advi.jl +++ b/src/objectives/elbo/advi.jl @@ -30,9 +30,9 @@ function ADVI(ℓπ, b⁻¹, n_samples::Int) end function (advi::ADVI)(q_η::ContinuousMultivariateDistribution; - rng ::Random.AbstractRNG = Random.default_rng(), - n_samples ::Int = advi.n_samples, - ηs ::AbstractMatrix = rand(rng, q_η, n_samples), + rng ::AbstractRNG = default_rng(), + n_samples ::Int = advi.n_samples, + ηs ::AbstractMatrix = rand(rng, q_η, n_samples), q_η_entropy::ContinuousMultivariateDistribution = q_η) 𝔼ℓ = advi.energy_estimator(q_η, ηs) ℍ = advi.entropy_estimator(q_η_entropy, ηs) @@ -40,7 +40,7 @@ function (advi::ADVI)(q_η::ContinuousMultivariateDistribution; end function estimate_gradient( - rng::Random.AbstractRNG, + rng::AbstractRNG, advi::ADVI, est_state, λ::Vector{<:Real}, From e19abd3d06291090f45b4b8b118e7be3003343c5 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 13 Jul 2023 03:08:46 +0100 Subject: [PATCH 033/144] fix location scale logpdf bug --- src/distributions/location_scale.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl index c46b5111..c1803ffe 100644 --- a/src/distributions/location_scale.jl +++ b/src/distributions/location_scale.jl @@ -23,19 +23,19 @@ Base.length(q::VILocationScale) = length(q.location) Base.size(q::VILocationScale) = size(q.location) function StatsBase.entropy(q::VILocationScale) - @unpack location, scale, dist = q + @unpack 
location, scale, dist = q n_dims = length(location) n_dims*entropy(dist) + first(logabsdet(scale)) end function logpdf(q::VILocationScale, z::AbstractVector{<:Real}) @unpack location, scale, dist = q - mapreduce(zᵢ -> logpdf(dist, zᵢ), +, scale \ (z - location)) + first(logabsdet(scale)) + mapreduce(zᵢ -> logpdf(dist, zᵢ), +, scale \ (z - location)) - first(logabsdet(scale)) end function _logpdf(q::VILocationScale, z::AbstractVector{<:Real}) @unpack location, scale, dist = q - mapreduce(zᵢ -> logpdf(dist, zᵢ), +, scale \ (z - location)) + first(logabsdet(scale)) + mapreduce(zᵢ -> logpdf(dist, zᵢ), +, scale \ (z - location)) - first(logabsdet(scale)) end function rand(q::VILocationScale) From 680c1864ecfe2a2867e9f48fe4bbf1ca37065aa3 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 13 Jul 2023 03:12:19 +0100 Subject: [PATCH 034/144] add Accessors dependency --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index d2708915..add1e391 100644 --- a/Project.toml +++ b/Project.toml @@ -3,6 +3,7 @@ uuid = "b5ca4192-6429-45e5-a2d9-87aec30a685c" version = "0.2.3" [deps] +Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" From 4c6cabf688af0552a307c22b821901cc792676be Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 13 Jul 2023 03:12:44 +0100 Subject: [PATCH 035/144] add location scale, autodiff tests --- test/ad.jl | 22 +++++++++++++++++++++ test/distributions.jl | 45 +++++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 32 ++++++++---------------------- 3 files changed, 75 insertions(+), 24 deletions(-) create mode 100644 test/ad.jl create mode 100644 test/distributions.jl diff --git a/test/ad.jl b/test/ad.jl new file mode 100644 index 00000000..c084165c --- /dev/null +++ b/test/ad.jl @@ -0,0 +1,22 @@ + +using ReTest +using ForwardDiff, ReverseDiff, Tracker, 
Enzyme, Zygote +using AdvancedVI: grad! + +@testset "ad" begin + @testset "$(string(adsymbol))" for adsymbol ∈ [ + :forwarddiff, :reversediff, :tracker, :enzyme, :zygote] + D = 10 + A = randn(D, D) + λ = randn(D) + AdvancedVI.setadbackend(adsymbol) + grad_buf = DiffResults.GradientResult(λ) + AdvancedVI.grad!(AdvancedVI.ADBackend(), λ, grad_buf) do λ′ + λ′'*A*λ′ / 2 + end + ∇ = DiffResults.gradient(grad_buf) + f = DiffResults.value(grad_buf) + @test ∇ ≈ (A + A')*λ/2 + @test f ≈ λ'*A*λ / 2 + end +end diff --git a/test/distributions.jl b/test/distributions.jl new file mode 100644 index 00000000..ab9617aa --- /dev/null +++ b/test/distributions.jl @@ -0,0 +1,45 @@ + +using ReTest +using Distributions +using Distributions: _logpdf +using LinearAlgebra +using AdvancedVI: LocationScale, VIFullRankGaussian, VIMeanFieldGaussian + +@testset "distributions" begin + @testset "$(string(covtype)) Gaussian $(realtype)" for + covtype = [:diagonal, :fullrank], + realtype = [Float32, Float64] + + realtype = Float64 + ϵ = 1e-2 + n_dims = 10 + n_montecarlo = 1000_000 + + μ = randn(realtype, n_dims) + L₀ = randn(realtype, n_dims, n_dims) + Σ = if covtype == :fullrank + Σ = (L₀*L₀' + ϵ*I) |> Hermitian + else + Diagonal(exp.(randn(realtype, n_dims))) + end + + L = cholesky(Σ).L + q = if covtype == :fullrank + VIFullRankGaussian(μ, L |> LowerTriangular) + else + VIMeanFieldGaussian(μ, L |> Diagonal) + end + q_true = MvNormal(μ, Σ) + + z = randn(n_dims) + @test logpdf(q, z) ≈ logpdf(q_true, z) + @test _logpdf(q, z) ≈ _logpdf(q_true, z) + @test entropy(q) ≈ entropy(q_true) + + z_samples = rand(q, n_montecarlo) + threesigma = L + @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ rtol=realtype(1e-2) + @test dropdims(var(z_samples, dims=2), dims=2) ≈ diag(Σ) rtol=realtype(1e-2) + @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index a305c25e..44074197 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,28 +1,12 @@ -using 
Test -using Distributions, DistributionsAD -using AdvancedVI -include("optimisers.jl") +using ReTest: @testset, @test +#using Random +#using Statistics +#using Distributions, DistributionsAD -target = MvNormal(ones(2)) -logπ(z) = logpdf(target, z) -advi = ADVI(10, 1000) +println("Environment variables for testing") +println(ENV) -# Using a function z ↦ q(⋅∣z) -getq(θ) = TuringDiagMvNormal(θ[1:2], exp.(θ[3:4])) -q = vi(logπ, advi, getq, randn(4)) - -xs = rand(target, 10) -@test mean(abs2, logpdf(q, xs) - logpdf(target, xs)) ≤ 0.05 - -# OR: implement `update` and pass a `Distribution` -function AdvancedVI.update(d::TuringDiagMvNormal, θ::AbstractArray{<:Real}) - return TuringDiagMvNormal(θ[1:length(q)], exp.(θ[length(q) + 1:end])) -end - -q0 = TuringDiagMvNormal(zeros(2), ones(2)) -q = vi(logπ, advi, q0, randn(4)) - -xs = rand(target, 10) -@test mean(abs2, logpdf(q, xs) - logpdf(target, xs)) ≤ 0.05 +include("ad.jl") +include("distributions.jl") From 06db2f02233e8e4e6010be6473ea7f356742a4a3 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 13 Jul 2023 03:15:03 +0100 Subject: [PATCH 036/144] add Accessors import statement --- src/AdvancedVI.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 76c6d859..5800cd93 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -1,7 +1,7 @@ module AdvancedVI -using UnPack +using UnPack, Accessors import Random: AbstractRNG, default_rng import Distributions: logpdf, _logpdf, rand, _rand!, _rand! 
@@ -179,6 +179,7 @@ include("objectives/elbo/advi_energy.jl") include("objectives/elbo/entropy.jl") # Variational Families + include("distributions/location_scale.jl") include("utils.jl") From 12de2bda787624b862772fc0b4fa55729ebb6ff9 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 13 Jul 2023 20:12:48 +0100 Subject: [PATCH 037/144] remove optimiser tests --- test/optimisers.jl | 17 ----------------- 1 file changed, 17 deletions(-) delete mode 100644 test/optimisers.jl diff --git a/test/optimisers.jl b/test/optimisers.jl deleted file mode 100644 index fae652ed..00000000 --- a/test/optimisers.jl +++ /dev/null @@ -1,17 +0,0 @@ -using Random, Test, LinearAlgebra, ForwardDiff -using AdvancedVI: TruncatedADAGrad, DecayedADAGrad, apply! - -θ = randn(10, 10) -@testset for opt in [TruncatedADAGrad(), DecayedADAGrad(1e-2)] - θ_fit = randn(10, 10) - loss(x, θ_) = mean(sum(abs2, θ*x - θ_*x; dims = 1)) - for t = 1:10^4 - x = rand(10) - Δ = ForwardDiff.gradient(θ_ -> loss(x, θ_), θ_fit) - Δ = apply!(opt, θ_fit, Δ) - @. 
θ_fit = θ_fit - Δ - end - @test loss(rand(10, 100), θ_fit) < 0.01 - @test length(opt.acc) == 1 -end - From bbb2cc649fce6caddb751d0e5743d2fc2a814ad2 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 13 Jul 2023 20:12:59 +0100 Subject: [PATCH 038/144] refactor slightly generalize the distribution tests for the future --- test/distributions.jl | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/test/distributions.jl b/test/distributions.jl index ab9617aa..07b3efdf 100644 --- a/test/distributions.jl +++ b/test/distributions.jl @@ -6,8 +6,9 @@ using LinearAlgebra using AdvancedVI: LocationScale, VIFullRankGaussian, VIMeanFieldGaussian @testset "distributions" begin - @testset "$(string(covtype)) Gaussian $(realtype)" for - covtype = [:diagonal, :fullrank], + @testset "$(string(covtype)) $(basedist) $(realtype)" for + basedist = [:gaussian], + covtype = [:meanfield, :fullrank], realtype = [Float32, Float64] realtype = Float64 @@ -24,12 +25,14 @@ using AdvancedVI: LocationScale, VIFullRankGaussian, VIMeanFieldGaussian end L = cholesky(Σ).L - q = if covtype == :fullrank + q = if covtype == :fullrank && basedist == :gaussian VIFullRankGaussian(μ, L |> LowerTriangular) - else + elseif covtype == :meanfield && basedist == :gaussian VIMeanFieldGaussian(μ, L |> Diagonal) end - q_true = MvNormal(μ, Σ) + q_true = if basedist == :gaussian + MvNormal(μ, Σ) + end z = randn(n_dims) @test logpdf(q, z) ≈ logpdf(q_true, z) From 197484655468ec5bab362380fb58d896a082b150 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 13 Jul 2023 23:10:51 +0100 Subject: [PATCH 039/144] migrate to SimpleUnPack, migrate to ADTypes --- Project.toml | 3 +- src/AdvancedVI.jl | 150 ++++++++++---------------------------- src/ad.jl | 46 ------------ src/compat/enzyme.jl | 19 ++++- src/compat/reversediff.jl | 21 +++--- src/compat/zygote.jl | 16 +++- test/ad.jl | 14 ++-- 7 files changed, 90 insertions(+), 179 deletions(-) delete mode 100644 src/ad.jl diff --git a/Project.toml 
b/Project.toml index 93e3a52a..2fcc845e 100644 --- a/Project.toml +++ b/Project.toml @@ -3,6 +3,7 @@ uuid = "b5ca4192-6429-45e5-a2d9-87aec30a685c" version = "0.2.4" [deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" @@ -18,10 +19,10 @@ PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" +SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" -UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" [compat] Bijectors = "0.11, 0.12, 0.13" diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 5800cd93..573f7179 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -1,7 +1,8 @@ module AdvancedVI -using UnPack, Accessors +using SimpleUnPack: @unpack +using Accessors import Random: AbstractRNG, default_rng import Distributions: logpdf, _logpdf, rand, _rand!, _rand! 
@@ -17,6 +18,8 @@ using LinearAlgebra: AbstractTriangular using LogDensityProblems +using ADTypes +using ADTypes: AbstractADType using ForwardDiff, Tracker using FillArrays @@ -30,78 +33,19 @@ using StatsBase: entropy const DEBUG = Bool(parse(Int, get(ENV, "DEBUG_ADVANCEDVI", "0"))) -include("ad.jl") - using Requires function __init__() @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin include("compat/zygote.jl") - export ZygoteAD - - function AdvancedVI.grad!( - f::Function, - ::Type{<:ZygoteAD}, - λ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult, - ) - y, back = Zygote.pullback(f, λ) - dy = first(back(1.0)) - DiffResults.value!(out, y) - DiffResults.gradient!(out, dy) - return out - end end @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin include("compat/reversediff.jl") - export ReverseDiffAD - - function AdvancedVI.grad!( - f::Function, - ::Type{<:ReverseDiffAD}, - λ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult, - ) - tp = AdvancedVI.tape(f, λ) - ReverseDiff.gradient!(out, tp, λ) - return out - end end @require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" begin include("compat/enzyme.jl") - export EnzymeAD - - function AdvancedVI.grad!( - f::Function, - ::Type{<:EnzymeAD}, - λ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult, - ) - # Use `Enzyme.ReverseWithPrimal` once it is released: - # https://github.com/EnzymeAD/Enzyme.jl/pull/598 - y = f(λ) - DiffResults.value!(out, y) - dy = DiffResults.gradient(out) - fill!(dy, 0) - Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(λ, dy)) - return out - end end end -export - optimize, - ELBO, - ADVI, - ADVIEnergy, - ClosedFormEntropy, - MonteCarloEntropy, - LocationScale, - FullRankGaussian, - MeanFieldGaussian, - TruncatedADAGrad, - DecayedADAGrad - - """ grad!(f, λ, out) @@ -111,55 +55,7 @@ This implicitly also gives a default implementation of `optimize!`. """ function grad! 
end -""" - optimize(model, alg::VariationalInference) - optimize(model, alg::VariationalInference, q::VariationalPosterior) - optimize(model, alg::VariationalInference, getq::Function, θ::AbstractArray) - -Constructs the variational posterior from the `model` and performs the optimization -following the configuration of the given `VariationalInference` instance. - -# Arguments -- `model`: `Turing.Model` or `Function` z ↦ log p(x, z) where `x` denotes the observations -- `alg`: the VI algorithm used -- `q`: a `VariationalPosterior` for which it is assumed a specialized implementation of the variational objective used exists. -- `getq`: function taking parameters `θ` as input and returns a `VariationalPosterior` -- `θ`: only required if `getq` is used, in which case it is the initial parameters for the variational posterior -""" -function optimize end - -function update end - -# default implementations -function grad!( - f::Function, - adtype::Type{<:ForwardDiffAD}, - λ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult -) - # Set chunk size and do ForwardMode. 
- chunk_size = getchunksize(adtype) - config = if chunk_size == 0 - ForwardDiff.GradientConfig(f, λ) - else - ForwardDiff.GradientConfig(f, λ, ForwardDiff.Chunk(length(λ), chunk_size)) - end - ForwardDiff.gradient!(out, f, λ, config) -end - -function grad!( - f::Function, - ::Type{<:TrackerAD}, - λ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult -) - λ_tracked = Tracker.param(λ) - y = f(λ_tracked) - Tracker.back!(y, 1.0) - - DiffResults.value!(out, Tracker.data(y)) - DiffResults.gradient!(out, Tracker.grad(λ_tracked)) -end +include("grad.jl") # estimators abstract type AbstractVariationalObjective end @@ -170,6 +66,9 @@ abstract type AbstractEnergyEstimator end abstract type AbstractEntropyEstimator end abstract type AbstractControlVariate end +function init end +function update end + init(::Nothing) = nothing update(::Nothing, ::Nothing) = (nothing, nothing) @@ -178,11 +77,42 @@ include("objectives/elbo/advi.jl") include("objectives/elbo/advi_energy.jl") include("objectives/elbo/entropy.jl") +export + ELBO, + ADVI, + ADVIEnergy, + ClosedFormEntropy, + MonteCarloEntropy + # Variational Families include("distributions/location_scale.jl") +export + VIFullRankGaussian, + VIMeanFieldGaussian + +""" + optimize(model, alg::VariationalInference) + optimize(model, alg::VariationalInference, q::VariationalPosterior) + optimize(model, alg::VariationalInference, getq::Function, θ::AbstractArray) + +Constructs the variational posterior from the `model` and performs the optimization +following the configuration of the given `VariationalInference` instance. + +# Arguments +- `model`: `Turing.Model` or `Function` z ↦ log p(x, z) where `x` denotes the observations +- `alg`: the VI algorithm used +- `q`: a `VariationalPosterior` for which it is assumed a specialized implementation of the variational objective used exists. 
+- `getq`: function taking parameters `θ` as input and returns a `VariationalPosterior` +- `θ`: only required if `getq` is used, in which case it is the initial parameters for the variational posterior +""" +function optimize end + +include("optimize.jl") + +export optimize + include("utils.jl") -include("vi.jl") end # module diff --git a/src/ad.jl b/src/ad.jl deleted file mode 100644 index 62e785e1..00000000 --- a/src/ad.jl +++ /dev/null @@ -1,46 +0,0 @@ -############################## -# Global variables/constants # -############################## -const ADBACKEND = Ref(:forwarddiff) -setadbackend(backend_sym::Symbol) = setadbackend(Val(backend_sym)) -function setadbackend(::Val{:forward_diff}) - Base.depwarn("`AdvancedVI.setadbackend(:forward_diff)` is deprecated. Please use `AdvancedVI.setadbackend(:forwarddiff)` to use `ForwardDiff`.", :setadbackend) - setadbackend(Val(:forwarddiff)) -end -function setadbackend(::Val{:forwarddiff}) - ADBACKEND[] = :forwarddiff -end - -function setadbackend(::Val{:reverse_diff}) - Base.depwarn("`AdvancedVI.setadbackend(:reverse_diff)` is deprecated. Please use `AdvancedVI.setadbackend(:tracker)` to use `Tracker` or `AdvancedVI.setadbackend(:reversediff)` to use `ReverseDiff`. 
To use `ReverseDiff`, please make sure it is loaded separately with `using ReverseDiff`.", :setadbackend) - setadbackend(Val(:tracker)) -end -function setadbackend(::Val{:tracker}) - ADBACKEND[] = :tracker -end - -const ADSAFE = Ref(false) -function setadsafe(switch::Bool) - @info("[AdvancedVI]: global ADSAFE is set as $switch") - ADSAFE[] = switch -end - -const CHUNKSIZE = Ref(0) # 0 means letting ForwardDiff set it automatically - -function setchunksize(chunk_size::Int) - @info("[AdvancedVI]: AD chunk size is set as $chunk_size") - CHUNKSIZE[] = chunk_size -end - -abstract type ADBackend end -struct ForwardDiffAD{chunk} <: ADBackend end -getchunksize(::Type{<:ForwardDiffAD{chunk}}) where chunk = chunk - -struct TrackerAD <: ADBackend end - -ADBackend() = ADBackend(ADBACKEND[]) -ADBackend(T::Symbol) = ADBackend(Val(T)) - -ADBackend(::Val{:forwarddiff}) = ForwardDiffAD{CHUNKSIZE[]} -ADBackend(::Val{:tracker}) = TrackerAD -ADBackend(::Val) = error("The requested AD backend is not available. 
Make sure to load all required packages.") diff --git a/src/compat/enzyme.jl b/src/compat/enzyme.jl index c6bb9ac3..cab50862 100644 --- a/src/compat/enzyme.jl +++ b/src/compat/enzyme.jl @@ -1,5 +1,16 @@ -struct EnzymeAD <: ADBackend end -ADBackend(::Val{:enzyme}) = EnzymeAD -function setadbackend(::Val{:enzyme}) - ADBACKEND[] = :enzyme + +function AdvancedVI.grad!( + f::Function, + ::AutoEnzyme, + λ::AbstractVector{<:Real}, + out::DiffResults.MutableDiffResult, + ) + # Use `Enzyme.ReverseWithPrimal` once it is released: + # https://github.com/EnzymeAD/Enzyme.jl/pull/598 + y = f(λ) + DiffResults.value!(out, y) + dy = DiffResults.gradient(out) + fill!(dy, 0) + Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(λ, dy)) + return out end diff --git a/src/compat/reversediff.jl b/src/compat/reversediff.jl index 721d0361..4d8f87d8 100644 --- a/src/compat/reversediff.jl +++ b/src/compat/reversediff.jl @@ -1,16 +1,19 @@ using .ReverseDiff: compile, GradientTape using .ReverseDiff.DiffResults: GradientResult -struct ReverseDiffAD{cache} <: ADBackend end -const RDCache = Ref(false) -setcache(b::Bool) = RDCache[] = b -getcache() = RDCache[] -ADBackend(::Val{:reversediff}) = ReverseDiffAD{getcache()} -function setadbackend(::Val{:reversediff}) - ADBACKEND[] = :reversediff -end - tape(f, x) = GradientTape(f, x) function taperesult(f, x) return tape(f, x), GradientResult(x) end + +# Precompiled tapes are not properly supported yet. 
+function AdvancedVI.grad!( + f::Function, + ::AutoReverseDiff, + λ::AbstractVector{<:Real}, + out::DiffResults.MutableDiffResult, + ) + tp = tape(f, λ) + ReverseDiff.gradient!(out, tp, λ) + return out +end diff --git a/src/compat/zygote.jl b/src/compat/zygote.jl index 40022e21..f1a29b87 100644 --- a/src/compat/zygote.jl +++ b/src/compat/zygote.jl @@ -1,5 +1,13 @@ -struct ZygoteAD <: ADBackend end -ADBackend(::Val{:zygote}) = ZygoteAD -function setadbackend(::Val{:zygote}) - ADBACKEND[] = :zygote + +function AdvancedVI.grad!( + f::Function, + ::AutoZygote, + λ::AbstractVector{<:Real}, + out::DiffResults.MutableDiffResult, + ) + y, back = Zygote.pullback(f, λ) + dy = first(back(1.0)) + DiffResults.value!(out, y) + DiffResults.gradient!(out, dy) + return out end diff --git a/test/ad.jl b/test/ad.jl index c084165c..6b587598 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -1,17 +1,21 @@ using ReTest using ForwardDiff, ReverseDiff, Tracker, Enzyme, Zygote -using AdvancedVI: grad! +using ADTypes @testset "ad" begin - @testset "$(string(adsymbol))" for adsymbol ∈ [ - :forwarddiff, :reversediff, :tracker, :enzyme, :zygote] + @testset "$(adname)" for (adname, adsymbol) ∈ Dict( + :ForwardDiffAuto => AutoForwardDiff(), + :ForwardDiff => AutoForwardDiff(10), + :ReverseDiff => AutoReverseDiff(), + :Zygote => AutoZygote(), + :Tracker => AutoTracker(), + ) D = 10 A = randn(D, D) λ = randn(D) - AdvancedVI.setadbackend(adsymbol) grad_buf = DiffResults.GradientResult(λ) - AdvancedVI.grad!(AdvancedVI.ADBackend(), λ, grad_buf) do λ′ + AdvancedVI.grad!(adsymbol, λ, grad_buf) do λ′ λ′'*A*λ′ / 2 end ∇ = DiffResults.gradient(grad_buf) From 19c62c888fafbed9271e66cf1c7ced7b11a90457 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 13 Jul 2023 23:11:12 +0100 Subject: [PATCH 040/144] rename vi.jl to optimize.jl --- src/{vi.jl => optimize.jl} | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) rename src/{vi.jl => optimize.jl} (89%) diff --git a/src/vi.jl b/src/optimize.jl similarity index 
89% rename from src/vi.jl rename to src/optimize.jl index 842f187e..07184900 100644 --- a/src/vi.jl +++ b/src/optimize.jl @@ -19,6 +19,7 @@ function optimize( progress ::Bool = true, callback! = nothing, terminate = (args...) -> false, + adback::AbstractADType = AutoForwardDiff(), ) opt_state = Optimisers.init(optimizer, λ) est_state = init(objective) @@ -33,7 +34,8 @@ function optimize( for t = 1:n_max_iter stat = (iteration=t,) - grad_buf, est_state, stat′ = estimate_gradient(rng, objective, est_state, λ, restructure, grad_buf) + grad_buf, est_state, stat′ = estimate_gradient( + rng, adback, objective, est_state, λ, restructure, grad_buf) g = DiffResults.gradient(grad_buf) stat = merge(stat, stat′) @@ -51,6 +53,9 @@ function optimize( AdvancedVI.DEBUG && @debug "Step $t" stat... + q = project_domain(q) + λ, _ = Optimisers.destructure(q) + pm_next!(prog, stat) stats[t] = stat From 63da51de8870575971b8e70e28dfc6c2265c5e30 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 13 Jul 2023 23:11:25 +0100 Subject: [PATCH 041/144] fix estimate_gradient to use adtypes --- src/objectives/elbo/advi.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl index 311a94f3..ed834273 100644 --- a/src/objectives/elbo/advi.jl +++ b/src/objectives/elbo/advi.jl @@ -41,6 +41,7 @@ end function estimate_gradient( rng::AbstractRNG, + adback::AbstractADType, advi::ADVI, est_state, λ::Vector{<:Real}, @@ -50,7 +51,7 @@ function estimate_gradient( # Gradient-stopping for computing the sticking-the-landing control variate q_η_stop = skip_entropy_gradient(advi.entropy_estimator) ? restructure(λ) : nothing - grad!(ADBackend(), λ, out) do λ′ + grad!(adback, λ, out) do λ′ q_η = restructure(λ′) q_η_entropy = skip_entropy_gradient(advi.entropy_estimator) ? 
q_η_stop : q_η -advi(q_η; rng, q_η_entropy) From 65ab47395fa4fe88b6b65323325c68b5c0ee078a Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 13 Jul 2023 23:17:20 +0100 Subject: [PATCH 042/144] add exact inference tests --- test/distributions.jl | 5 +-- test/exact.jl | 64 +++++++++++++++++++++++++++++++++++ test/exact/normallognormal.jl | 52 ++++++++++++++++++++++++++++ test/runtests.jl | 13 +++---- 4 files changed, 124 insertions(+), 10 deletions(-) create mode 100644 test/exact.jl create mode 100644 test/exact/normallognormal.jl diff --git a/test/distributions.jl b/test/distributions.jl index 07b3efdf..074cad7c 100644 --- a/test/distributions.jl +++ b/test/distributions.jl @@ -1,9 +1,6 @@ using ReTest -using Distributions using Distributions: _logpdf -using LinearAlgebra -using AdvancedVI: LocationScale, VIFullRankGaussian, VIMeanFieldGaussian @testset "distributions" begin @testset "$(string(covtype)) $(basedist) $(realtype)" for @@ -17,7 +14,7 @@ using AdvancedVI: LocationScale, VIFullRankGaussian, VIMeanFieldGaussian n_montecarlo = 1000_000 μ = randn(realtype, n_dims) - L₀ = randn(realtype, n_dims, n_dims) + L₀ = randn(realtype, n_dims, n_dims) |> LowerTriangular Σ = if covtype == :fullrank Σ = (L₀*L₀' + ϵ*I) |> Hermitian else diff --git a/test/exact.jl b/test/exact.jl new file mode 100644 index 00000000..27b92c04 --- /dev/null +++ b/test/exact.jl @@ -0,0 +1,64 @@ + +const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress" ? 
true : false + +using ReTest +using Turing, LogDensityProblems +using Optimisers +using Distributions +using LinearAlgebra +using SimpleUnPack: @unpack + +struct TestModel{M,L,S} + model::M + μ_true::L + L_true::S + n_dims::Int + is_meanfield::Bool +end + +include("inference/normallognormal.jl") + +@testset "exact" begin + @testset "$(modelname) $(realtype)" for + realtype ∈ [Float32, Float64], + (modelname, modelconstr) ∈ Dict( + :NormalLogNormalMeanField => normallognormal_meanfield, + :NormalLogNormalFullRank => normallognormal_fullrank, + ) + + T = 10000 + modelstats = modelconstr(realtype) + @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats + + b = Bijectors.bijector(model) + b⁻¹ = inverse(b) + prob = DynamicPPL.LogDensityFunction(model) + + μ₀ = zeros(realtype, n_dims) + L₀ = if is_meanfield + ones(realtype, n_dims) |> Diagonal + else + diagm(ones(realtype, n_dims)) |> LowerTriangular + end + q₀ = if is_meanfield + AdvancedVI.VIMeanFieldGaussian(μ₀, L₀, realtype(1e-8)) + else + AdvancedVI.VIFullRankGaussian(μ₀, L₀, realtype(1e-8)) + end + + Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true) + + objective = AdvancedVI.ADVI(prob, b⁻¹, 10) + q, stats = AdvancedVI.optimize( + objective, q₀, T; + optimizer = Optimisers.AdaGrad(1e-1), + progress = PROGRESS, + ) + + μ = q.location + L = q.scale + Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true) + + @test Δλ ≤ Δλ₀/√T + end +end diff --git a/test/exact/normallognormal.jl b/test/exact/normallognormal.jl new file mode 100644 index 00000000..4e9e1404 --- /dev/null +++ b/test/exact/normallognormal.jl @@ -0,0 +1,52 @@ + +function normallognormal_fullrank(realtype; rng = default_rng()) + n_dims = 5 + + μ_x = randn(rng, realtype) + σ_x = π + μ_y = randn(rng, realtype, n_dims) + L₀_y = randn(rng, realtype, n_dims, n_dims) |> LowerTriangular + ϵ = realtype(1.0) + Σ_y = (L₀_y*L₀_y' + ϵ*I) |> Hermitian + + Turing.@model function normallognormal() + x ~ LogNormal(μ_x, σ_x) + y ~ MvNormal(μ_y, Σ_y) + end + 
model = normallognormal() + + Σ = Matrix{realtype}(undef, n_dims+1, n_dims+1) + Σ[1,1] = σ_x^2 + Σ[2:end,2:end] = Σ_y + Σ = Σ |> Hermitian + + μ = vcat(μ_x, μ_y) + L = cholesky(Σ).L |> LowerTriangular + + TestModel(model, μ, L, n_dims+1, false) +end + +function normallognormal_meanfield(realtype) + n_dims = 5 + + μ_x = randn(realtype) + σ_x = π + μ_y = randn(realtype, n_dims) + ϵ = realtype(1.0) + Σ_y = Diagonal(exp.(randn(realtype, n_dims))) + + Turing.@model function normallognormal() + x ~ LogNormal(μ_x, σ_x) + y ~ MvNormal(μ_y, Σ_y) + end + model = normallognormal() + + σ² = Vector{realtype}(undef, n_dims+1) + σ²[1] = σ_x^2 + σ²[2:end] = diag(Σ_y) + + μ = vcat(μ_x, μ_y) + L = sqrt.(σ²) |> Diagonal + + TestModel(model, μ, L, n_dims+1, true) +end diff --git a/test/runtests.jl b/test/runtests.jl index 44074197..26f9a06f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,12 +1,13 @@ using ReTest: @testset, @test -#using Random -#using Statistics -#using Distributions, DistributionsAD - -println("Environment variables for testing") -println(ENV) +using Random +using Random: default_rng +using Statistics +using Distributions, DistributionsAD +using LinearAlgebra +using AdvancedVI include("ad.jl") include("distributions.jl") +include("exact.jl") From 3e5a4520835f0d182b8f7c4aaef0529ff37498e6 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Fri, 14 Jul 2023 00:28:18 +0100 Subject: [PATCH 043/144] remove Turing dependency in tests --- test/exact.jl | 9 ++++--- test/exact/normallognormal.jl | 47 +++++++++++++++++++++++------------ test/runtests.jl | 9 ++++++- 3 files changed, 44 insertions(+), 21 deletions(-) diff --git a/test/exact.jl b/test/exact.jl index 27b92c04..d5283e8e 100644 --- a/test/exact.jl +++ b/test/exact.jl @@ -2,9 +2,11 @@ const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress" ? 
true : false using ReTest -using Turing, LogDensityProblems +using Bijectors +using LogDensityProblems using Optimisers using Distributions +using PDMats using LinearAlgebra using SimpleUnPack: @unpack @@ -16,7 +18,7 @@ struct TestModel{M,L,S} is_meanfield::Bool end -include("inference/normallognormal.jl") +include("exact/normallognormal.jl") @testset "exact" begin @testset "$(modelname) $(realtype)" for @@ -32,7 +34,6 @@ include("inference/normallognormal.jl") b = Bijectors.bijector(model) b⁻¹ = inverse(b) - prob = DynamicPPL.LogDensityFunction(model) μ₀ = zeros(realtype, n_dims) L₀ = if is_meanfield @@ -48,7 +49,7 @@ include("inference/normallognormal.jl") Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true) - objective = AdvancedVI.ADVI(prob, b⁻¹, 10) + objective = AdvancedVI.ADVI(model, b⁻¹, 10) q, stats = AdvancedVI.optimize( objective, q₀, T; optimizer = Optimisers.AdaGrad(1e-1), diff --git a/test/exact/normallognormal.jl b/test/exact/normallognormal.jl index 4e9e1404..e39ec2cb 100644 --- a/test/exact/normallognormal.jl +++ b/test/exact/normallognormal.jl @@ -1,4 +1,31 @@ +struct NormalLogNormal{MX,SX,MY,SY} + μ_x::MX + σ_x::SX + μ_y::MY + Σ_y::SY +end + +function LogDensityProblems.logdensity(model::NormalLogNormal, θ) + @unpack μ_x, σ_x, μ_y, Σ_y = model + logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end]) +end + +function LogDensityProblems.dimension(model::NormalLogNormal) + length(model.μ_y) + 1 +end + +function LogDensityProblems.capabilities(::Type{<:NormalLogNormal}) + LogDensityProblems.LogDensityOrder{0}() +end + +function Bijectors.bijector(model::NormalLogNormal) + @unpack μ_x, σ_x, μ_y, Σ_y = model + Bijectors.Stacked( + Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]), + [1:1, 2:1+length(μ_y)]) +end + function normallognormal_fullrank(realtype; rng = default_rng()) n_dims = 5 @@ -9,11 +36,7 @@ function normallognormal_fullrank(realtype; rng = default_rng()) ϵ = realtype(1.0) Σ_y = (L₀_y*L₀_y' + ϵ*I) |> 
Hermitian - Turing.@model function normallognormal() - x ~ LogNormal(μ_x, σ_x) - y ~ MvNormal(μ_y, Σ_y) - end - model = normallognormal() + model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDMat(Σ_y)) Σ = Matrix{realtype}(undef, n_dims+1, n_dims+1) Σ[1,1] = σ_x^2 @@ -33,20 +56,12 @@ function normallognormal_meanfield(realtype) σ_x = π μ_y = randn(realtype, n_dims) ϵ = realtype(1.0) - Σ_y = Diagonal(exp.(randn(realtype, n_dims))) - - Turing.@model function normallognormal() - x ~ LogNormal(μ_x, σ_x) - y ~ MvNormal(μ_y, Σ_y) - end - model = normallognormal() + σ_y = exp.(randn(realtype, n_dims)) - σ² = Vector{realtype}(undef, n_dims+1) - σ²[1] = σ_x^2 - σ²[2:end] = diag(Σ_y) + model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2)) μ = vcat(μ_x, μ_y) - L = sqrt.(σ²) |> Diagonal + L = vcat(σ_x, σ_y) |> Diagonal TestModel(model, μ, L, n_dims+1, true) end diff --git a/test/runtests.jl b/test/runtests.jl index 26f9a06f..0b86222b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,13 +1,20 @@ +using Comonicon using ReTest: @testset, @test using Random using Random: default_rng using Statistics -using Distributions, DistributionsAD +using Distributions using LinearAlgebra using AdvancedVI +const GROUP = get(ENV, "AHMC_TEST_GROUP", "AdvancedHMC") + include("ad.jl") include("distributions.jl") include("exact.jl") +@main function runtests(patterns...; dry::Bool = false) + retest(patterns...; dry = dry, verbose = Inf) +end + From 3117cec8952b80b58e205726f2abe9f77ffddf80 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Fri, 14 Jul 2023 02:44:22 +0100 Subject: [PATCH 044/144] remove unused projection --- src/optimize.jl | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/optimize.jl b/src/optimize.jl index 07184900..2acfbc0b 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -43,7 +43,6 @@ function optimize( Optimisers.subtract!(λ, Δλ) stat′ = (iteration=t, Δλ=norm(Δλ), gradient_norm=norm(g)) stat = merge(stat, stat′) - q = restructure(λ) if !isnothing(callback!) 
@@ -53,9 +52,6 @@ function optimize( AdvancedVI.DEBUG && @debug "Step $t" stat... - q = project_domain(q) - λ, _ = Optimisers.destructure(q) - pm_next!(prog, stat) stats[t] = stat From b1ca9cf5cfad2345c92481c7519b12e1520776ef Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Fri, 14 Jul 2023 03:03:57 +0100 Subject: [PATCH 045/144] remove redundant `ADVIEnergy` object (now baked into `ADVI`) --- src/AdvancedVI.jl | 2 +- src/objectives/elbo/advi.jl | 38 ++++++++++++++++++++---------- src/objectives/elbo/advi_energy.jl | 37 ----------------------------- 3 files changed, 26 insertions(+), 51 deletions(-) delete mode 100644 src/objectives/elbo/advi_energy.jl diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 573f7179..502112c7 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -74,7 +74,6 @@ init(::Nothing) = nothing update(::Nothing, ::Nothing) = (nothing, nothing) include("objectives/elbo/advi.jl") -include("objectives/elbo/advi_energy.jl") include("objectives/elbo/entropy.jl") export @@ -82,6 +81,7 @@ export ADVI, ADVIEnergy, ClosedFormEntropy, + StickingTheLandingEntropy, MonteCarloEntropy # Variational Families diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl index ed834273..9cd2433e 100644 --- a/src/objectives/elbo/advi.jl +++ b/src/objectives/elbo/advi.jl @@ -1,32 +1,41 @@ -struct ADVI{EnergyEst <: AbstractEnergyEstimator, +struct ADVI{Tlogπ, B, EntropyEst <: AbstractEntropyEstimator, ControlVar <: Union{<: AbstractControlVariate, Nothing}} <: AbstractVariationalObjective - energy_estimator::EnergyEst + ℓπ::Tlogπ + b⁻¹::B entropy_estimator::EntropyEst control_variate::ControlVar n_samples::Int + + function ADVI(prob, b⁻¹, entropy_estimator, control_variate, n_samples) + cap = LogDensityProblems.capabilities(prob) + if cap === nothing + throw( + ArgumentError( + "The log density function does not support the LogDensityProblems.jl interface", + ), + ) + end + ℓπ = Base.Fix1(LogDensityProblems.logdensity, prob) + new{typeof(ℓπ), 
typeof(b⁻¹), typeof(entropy_estimator), typeof(control_variate)}( + ℓπ, b⁻¹, entropy_estimator, control_variate, n_samples + ) + end end skip_entropy_gradient(advi::ADVI) = skip_entropy_gradient(advi.entropy_estimator) init(advi::ADVI) = init(advi.control_variate) -Base.show(io::IO, advi::ADVI) = print( - io, - "ADVI(energy_estimator=$(advi.energy_estimator), " * - "entropy_estimator=$(advi.entropy_estimator), " * - "control_variate=$(advi.control_variate), " * - "n_samples=$(advi.n_samples))") - -function ADVI(energy_estimator::AbstractEnergyEstimator, +function ADVI(ℓπ, b⁻¹, entropy_estimator::AbstractEntropyEstimator, n_samples::Int) - ADVI(energy_estimator, entropy_estimator, nothing, n_samples) + ADVI(ℓπ, b⁻¹, entropy_estimator, nothing, n_samples) end function ADVI(ℓπ, b⁻¹, n_samples::Int) - ADVI(ADVIEnergy(ℓπ, b⁻¹), ClosedFormEntropy(), n_samples) + ADVI(ℓπ, b⁻¹, ClosedFormEntropy(), nothing, n_samples) end function (advi::ADVI)(q_η::ContinuousMultivariateDistribution; @@ -34,7 +43,10 @@ function (advi::ADVI)(q_η::ContinuousMultivariateDistribution; n_samples ::Int = advi.n_samples, ηs ::AbstractMatrix = rand(rng, q_η, n_samples), q_η_entropy::ContinuousMultivariateDistribution = q_η) - 𝔼ℓ = advi.energy_estimator(q_η, ηs) + 𝔼ℓ = mapreduce(+, eachcol(ηs)) do ηᵢ + zᵢ, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(advi.b⁻¹, ηᵢ) + (advi.ℓπ(zᵢ) + logdetjacᵢ) / n_samples + end ℍ = advi.entropy_estimator(q_η_entropy, ηs) 𝔼ℓ + ℍ end diff --git a/src/objectives/elbo/advi_energy.jl b/src/objectives/elbo/advi_energy.jl deleted file mode 100644 index 078a157e..00000000 --- a/src/objectives/elbo/advi_energy.jl +++ /dev/null @@ -1,37 +0,0 @@ - -struct ADVIEnergy{Tlogπ, B} <: AbstractEnergyEstimator - # Automatic differentiation variational inference - # - # Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). - # Automatic differentiation variational inference. - # Journal of machine learning research. 
- - ℓπ::Tlogπ - b⁻¹::B - - function ADVIEnergy(prob, b⁻¹) - # Could check whether the support of b⁻¹ and ℓπ match - cap = LogDensityProblems.capabilities(prob) - if cap === nothing - throw( - ArgumentError( - "The log density function does not support the LogDensityProblems.jl interface", - ), - ) - end - ℓπ = Base.Fix1(LogDensityProblems.logdensity, prob) - new{typeof(ℓπ), typeof(b⁻¹)}(ℓπ, b⁻¹) - end -end - -ADVIEnergy(prob) = ADVIEnergy(prob, identity) - -Base.show(io::IO, energy::ADVIEnergy) = print(io, "ADVIEnergy()") - -function (energy::ADVIEnergy)(q, ηs::AbstractMatrix) - n_samples = size(ηs, 2) - mapreduce(+, eachcol(ηs)) do ηᵢ - zᵢ, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(energy.b⁻¹, ηᵢ) - (energy.ℓπ(zᵢ) + logdetjacᵢ) / n_samples - end -end From fcbb729378e3e4e16e6288a9336511f2b616b557 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Fri, 14 Jul 2023 03:04:21 +0100 Subject: [PATCH 046/144] add more tests, fix rng seed for tests --- test/exact.jl | 69 +++++++++++++++++++++++++++-------- test/exact/normallognormal.jl | 15 ++++---- test/runtests.jl | 2 +- 3 files changed, 61 insertions(+), 25 deletions(-) diff --git a/test/exact.jl b/test/exact.jl index d5283e8e..637a95ed 100644 --- a/test/exact.jl +++ b/test/exact.jl @@ -21,15 +21,22 @@ end include("exact/normallognormal.jl") @testset "exact" begin - @testset "$(modelname) $(realtype)" for + @testset "$(modelname) $(objname) $(realtype)" for realtype ∈ [Float32, Float64], (modelname, modelconstr) ∈ Dict( :NormalLogNormalMeanField => normallognormal_meanfield, :NormalLogNormalFullRank => normallognormal_fullrank, + ), + (objname, objective) ∈ Dict( + :ADVIClosedFormEntropy => (model, b⁻¹, M) -> ADVI(model, b⁻¹, M), + :ADVIStickingTheLanding => (model, b⁻¹, M) -> ADVI(model, b⁻¹, StickingTheLandingEntropy(), M), + :ADVIFullMonteCarlo => (model, b⁻¹, M) -> ADVI(model, b⁻¹, MonteCarloEntropy(), M), ) - + seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797) + rng = Philox4x(UInt64, seed, 8) + T = 10000 - modelstats 
= modelconstr(realtype) + modelstats = modelconstr(realtype; rng) @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats b = Bijectors.bijector(model) @@ -42,24 +49,54 @@ include("exact/normallognormal.jl") diagm(ones(realtype, n_dims)) |> LowerTriangular end q₀ = if is_meanfield - AdvancedVI.VIMeanFieldGaussian(μ₀, L₀, realtype(1e-8)) + VIMeanFieldGaussian(μ₀, L₀, realtype(1e-8)) else - AdvancedVI.VIFullRankGaussian(μ₀, L₀, realtype(1e-8)) + VIFullRankGaussian(μ₀, L₀, realtype(1e-8)) end - Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true) + obj = objective(model, b⁻¹, 10) - objective = AdvancedVI.ADVI(model, b⁻¹, 10) - q, stats = AdvancedVI.optimize( - objective, q₀, T; - optimizer = Optimisers.AdaGrad(1e-1), - progress = PROGRESS, - ) + @testset "convergence" begin + Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true) + q, stats = optimize( + obj, q₀, T; + optimizer = Optimisers.AdaGrad(1e-0), + progress = PROGRESS, + rng = rng, + ) - μ = q.location - L = q.scale - Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true) + μ = q.location + L = q.scale + Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true) + + @test Δλ ≤ Δλ₀/√T + @test eltype(μ) == eltype(μ_true) + @test eltype(L) == eltype(L_true) + end - @test Δλ ≤ Δλ₀/√T + @testset "determinism" begin + rng = Philox4x(UInt64, seed, 8) + q, stats = optimize( + obj, q₀, T; + optimizer = Optimisers.AdaGrad(1e-2), + progress = PROGRESS, + rng = rng, + ) + μ = q.location + L = q.scale + + rng_repl = Philox4x(UInt64, seed, 8) + q, stats = optimize( + obj, q₀, T; + optimizer = Optimisers.AdaGrad(1e-2), + progress = PROGRESS, + rng = rng_repl, + ) + μ_repl = q.location + L_repl = q.scale + @test μ == μ_repl + @test L == L_repl + end end end + diff --git a/test/exact/normallognormal.jl b/test/exact/normallognormal.jl index e39ec2cb..7c5c000d 100644 --- a/test/exact/normallognormal.jl +++ b/test/exact/normallognormal.jl @@ -30,10 +30,10 @@ function normallognormal_fullrank(realtype; rng = default_rng()) n_dims = 
5 μ_x = randn(rng, realtype) - σ_x = π + σ_x = ℯ μ_y = randn(rng, realtype, n_dims) L₀_y = randn(rng, realtype, n_dims, n_dims) |> LowerTriangular - ϵ = realtype(1.0) + ϵ = realtype(n_dims) Σ_y = (L₀_y*L₀_y' + ϵ*I) |> Hermitian model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDMat(Σ_y)) @@ -49,14 +49,13 @@ function normallognormal_fullrank(realtype; rng = default_rng()) TestModel(model, μ, L, n_dims+1, false) end -function normallognormal_meanfield(realtype) +function normallognormal_meanfield(realtype; rng = default_rng()) n_dims = 5 - μ_x = randn(realtype) - σ_x = π - μ_y = randn(realtype, n_dims) - ϵ = realtype(1.0) - σ_y = exp.(randn(realtype, n_dims)) + μ_x = randn(rng, realtype) + σ_x = ℯ + μ_y = randn(rng, realtype, n_dims) + σ_y = log.(exp.(randn(rng, realtype, n_dims)) .+ 1) model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2)) diff --git a/test/runtests.jl b/test/runtests.jl index 0b86222b..b571f8b8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,7 +2,7 @@ using Comonicon using ReTest: @testset, @test using Random -using Random: default_rng +using Random123 using Statistics using Distributions using LinearAlgebra From 0f6f6a429ba74e491943ad96fa52ff9f897cc862 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Fri, 14 Jul 2023 03:04:35 +0100 Subject: [PATCH 047/144] add more tests, fix seed for tests --- test/distributions.jl | 37 +++++++++++++++++++++++++------------ 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/test/distributions.jl b/test/distributions.jl index 074cad7c..073fff64 100644 --- a/test/distributions.jl +++ b/test/distributions.jl @@ -8,17 +8,19 @@ using Distributions: _logpdf covtype = [:meanfield, :fullrank], realtype = [Float32, Float64] + seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797) + rng = Philox4x(UInt64, seed, 8) realtype = Float64 ϵ = 1e-2 n_dims = 10 n_montecarlo = 1000_000 - μ = randn(realtype, n_dims) - L₀ = randn(realtype, n_dims, n_dims) |> LowerTriangular + μ = randn(rng, realtype, n_dims) + L₀ = 
randn(rng, realtype, n_dims, n_dims) |> LowerTriangular Σ = if covtype == :fullrank Σ = (L₀*L₀' + ϵ*I) |> Hermitian else - Diagonal(exp.(randn(realtype, n_dims))) + Diagonal(log.(exp.(randn(rng, realtype, n_dims)) .+ 1)) end L = cholesky(Σ).L @@ -31,15 +33,26 @@ using Distributions: _logpdf MvNormal(μ, Σ) end - z = randn(n_dims) - @test logpdf(q, z) ≈ logpdf(q_true, z) - @test _logpdf(q, z) ≈ _logpdf(q_true, z) - @test entropy(q) ≈ entropy(q_true) + @testset "logpdf" begin + z = randn(rng, realtype, n_dims) + @test logpdf(q, z) ≈ logpdf(q_true, z) + @test _logpdf(q, z) ≈ _logpdf(q_true, z) + @test eltype(logpdf(q, z)) == realtype + @test eltype(_logpdf(q, z)) == realtype + end + + @testset "entropy" begin + @test eltype(entropy(q)) == realtype + @test entropy(q) ≈ entropy(q_true) + end - z_samples = rand(q, n_montecarlo) - threesigma = L - @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ rtol=realtype(1e-2) - @test dropdims(var(z_samples, dims=2), dims=2) ≈ diag(Σ) rtol=realtype(1e-2) - @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) + @testset "sampling" begin + z_samples = rand(rng, q, n_montecarlo) + threesigma = L + @test eltype(z_samples) == realtype + @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ rtol=realtype(1e-2) + @test dropdims(var(z_samples, dims=2), dims=2) ≈ diag(Σ) rtol=realtype(1e-2) + @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) + end end end From f5f5863b55af07ea1009528e5b8e1fdb1bfc96df Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Fri, 14 Jul 2023 03:16:49 +0100 Subject: [PATCH 048/144] fix non-determinism bug --- src/distributions/location_scale.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl index c1803ffe..e9e8c743 100644 --- a/src/distributions/location_scale.jl +++ b/src/distributions/location_scale.jl @@ -47,7 +47,7 @@ end function rand(rng::AbstractRNG, q::VILocationScale, num_samples::Int) @unpack location, scale, dist = q 
n_dims = length(location) - scale*rand(dist, n_dims, num_samples) .+ location + scale*rand(rng, dist, n_dims, num_samples) .+ location end function _rand!(rng::AbstractRNG, q::VILocationScale, x::AbstractVector{<:Real}) From ade0d1007c1507fb0359d744fa640349314e325d Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Fri, 14 Jul 2023 04:56:29 +0100 Subject: [PATCH 049/144] fix test hyperparameters so that tests pass, minor cleanups --- src/distributions/location_scale.jl | 12 ++++++++++++ src/objectives/elbo/advi.jl | 6 ++++++ src/optimize.jl | 6 ++++-- test/exact.jl | 10 +++++----- test/exact/normallognormal.jl | 2 +- 5 files changed, 28 insertions(+), 8 deletions(-) diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl index e9e8c743..dc9c1b27 100644 --- a/src/distributions/location_scale.jl +++ b/src/distributions/location_scale.jl @@ -1,4 +1,16 @@ +""" + +The [location scale] variational family broadly represents various variational +families using `location` and `scale` variational parameters. 
+ +Multivariate Student-t variational family with ``\\nu``-degrees of freedom can +be constructed as: +```julia +q₀ = VILocationScale(μ, L, StudentT(ν), eps(Float32)) +``` + +""" struct VILocationScale{L, S, D, R} <: ContinuousMultivariateDistribution location::L scale ::S diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl index 9cd2433e..b9b1185f 100644 --- a/src/objectives/elbo/advi.jl +++ b/src/objectives/elbo/advi.jl @@ -24,6 +24,12 @@ struct ADVI{Tlogπ, B, end end +Base.show(io::IO, advi::ADVI) = + print(io, + "ADVI(entropy_estimator=$(advi.entropy_estimator), " * + "control_variate=$(advi.control_variate), " * + "n_samples=$(advi.n_samples))") + skip_entropy_gradient(advi::ADVI) = skip_entropy_gradient(advi.entropy_estimator) init(advi::ADVI) = init(advi.control_variate) diff --git a/src/optimize.jl b/src/optimize.jl index 2acfbc0b..dcd1c439 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -41,9 +41,11 @@ function optimize( opt_state, Δλ = Optimisers.apply!(optimizer, opt_state, λ, g) Optimisers.subtract!(λ, Δλ) + stat′ = (iteration=t, Δλ=norm(Δλ), gradient_norm=norm(g)) stat = merge(stat, stat′) - q = restructure(λ) + + q = restructure(λ) if !isnothing(callback!) 
stat′ = callback!(q, stat) @@ -56,7 +58,7 @@ function optimize( stats[t] = stat # Termination decision is work in progress - if terminate(rng, q, objective, stat) + if terminate(rng, λ, q, objective, stat) stats = stats[1:t] break end diff --git a/test/exact.jl b/test/exact.jl index 637a95ed..d1be4626 100644 --- a/test/exact.jl +++ b/test/exact.jl @@ -49,9 +49,9 @@ include("exact/normallognormal.jl") diagm(ones(realtype, n_dims)) |> LowerTriangular end q₀ = if is_meanfield - VIMeanFieldGaussian(μ₀, L₀, realtype(1e-8)) + VIMeanFieldGaussian(μ₀, L₀) else - VIFullRankGaussian(μ₀, L₀, realtype(1e-8)) + VIFullRankGaussian(μ₀, L₀) end obj = objective(model, b⁻¹, 10) @@ -60,7 +60,7 @@ include("exact/normallognormal.jl") Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true) q, stats = optimize( obj, q₀, T; - optimizer = Optimisers.AdaGrad(1e-0), + optimizer = Optimisers.AdaGrad(1e-1), progress = PROGRESS, rng = rng, ) @@ -78,7 +78,7 @@ include("exact/normallognormal.jl") rng = Philox4x(UInt64, seed, 8) q, stats = optimize( obj, q₀, T; - optimizer = Optimisers.AdaGrad(1e-2), + optimizer = Optimisers.AdaGrad(1e-1), progress = PROGRESS, rng = rng, ) @@ -88,7 +88,7 @@ include("exact/normallognormal.jl") rng_repl = Philox4x(UInt64, seed, 8) q, stats = optimize( obj, q₀, T; - optimizer = Optimisers.AdaGrad(1e-2), + optimizer = Optimisers.AdaGrad(1e-1), progress = PROGRESS, rng = rng_repl, ) diff --git a/test/exact/normallognormal.jl b/test/exact/normallognormal.jl index 7c5c000d..18e8b4a3 100644 --- a/test/exact/normallognormal.jl +++ b/test/exact/normallognormal.jl @@ -33,7 +33,7 @@ function normallognormal_fullrank(realtype; rng = default_rng()) σ_x = ℯ μ_y = randn(rng, realtype, n_dims) L₀_y = randn(rng, realtype, n_dims, n_dims) |> LowerTriangular - ϵ = realtype(n_dims) + ϵ = realtype(n_dims*2) Σ_y = (L₀_y*L₀_y' + ϵ*I) |> Hermitian model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDMat(Σ_y)) From 0caf7a9ef768ce97c7498c981d5ef60ee673488f Mon Sep 17 00:00:00 2001 From: Ray Kim 
Date: Fri, 14 Jul 2023 19:37:45 +0100 Subject: [PATCH 050/144] fix minor reorganization --- src/AdvancedVI.jl | 9 +-- test/exact.jl | 102 ---------------------------------- test/exact/normallognormal.jl | 66 ---------------------- test/runtests.jl | 4 +- 4 files changed, 4 insertions(+), 177 deletions(-) delete mode 100644 test/exact.jl delete mode 100644 test/exact/normallognormal.jl diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 502112c7..86c9fc44 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -1,7 +1,7 @@ module AdvancedVI -using SimpleUnPack: @unpack +using SimpleUnPack: @unpack, @pack! using Accessors import Random: AbstractRNG, default_rng @@ -60,17 +60,14 @@ include("grad.jl") # estimators abstract type AbstractVariationalObjective end +function init end function estimate_gradient end -abstract type AbstractEnergyEstimator end +# ADVI-specific interfaces abstract type AbstractEntropyEstimator end abstract type AbstractControlVariate end -function init end function update end - -init(::Nothing) = nothing - update(::Nothing, ::Nothing) = (nothing, nothing) include("objectives/elbo/advi.jl") diff --git a/test/exact.jl b/test/exact.jl deleted file mode 100644 index d1be4626..00000000 --- a/test/exact.jl +++ /dev/null @@ -1,102 +0,0 @@ - -const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress" ? 
true : false - -using ReTest -using Bijectors -using LogDensityProblems -using Optimisers -using Distributions -using PDMats -using LinearAlgebra -using SimpleUnPack: @unpack - -struct TestModel{M,L,S} - model::M - μ_true::L - L_true::S - n_dims::Int - is_meanfield::Bool -end - -include("exact/normallognormal.jl") - -@testset "exact" begin - @testset "$(modelname) $(objname) $(realtype)" for - realtype ∈ [Float32, Float64], - (modelname, modelconstr) ∈ Dict( - :NormalLogNormalMeanField => normallognormal_meanfield, - :NormalLogNormalFullRank => normallognormal_fullrank, - ), - (objname, objective) ∈ Dict( - :ADVIClosedFormEntropy => (model, b⁻¹, M) -> ADVI(model, b⁻¹, M), - :ADVIStickingTheLanding => (model, b⁻¹, M) -> ADVI(model, b⁻¹, StickingTheLandingEntropy(), M), - :ADVIFullMonteCarlo => (model, b⁻¹, M) -> ADVI(model, b⁻¹, MonteCarloEntropy(), M), - ) - seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797) - rng = Philox4x(UInt64, seed, 8) - - T = 10000 - modelstats = modelconstr(realtype; rng) - @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats - - b = Bijectors.bijector(model) - b⁻¹ = inverse(b) - - μ₀ = zeros(realtype, n_dims) - L₀ = if is_meanfield - ones(realtype, n_dims) |> Diagonal - else - diagm(ones(realtype, n_dims)) |> LowerTriangular - end - q₀ = if is_meanfield - VIMeanFieldGaussian(μ₀, L₀) - else - VIFullRankGaussian(μ₀, L₀) - end - - obj = objective(model, b⁻¹, 10) - - @testset "convergence" begin - Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true) - q, stats = optimize( - obj, q₀, T; - optimizer = Optimisers.AdaGrad(1e-1), - progress = PROGRESS, - rng = rng, - ) - - μ = q.location - L = q.scale - Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true) - - @test Δλ ≤ Δλ₀/√T - @test eltype(μ) == eltype(μ_true) - @test eltype(L) == eltype(L_true) - end - - @testset "determinism" begin - rng = Philox4x(UInt64, seed, 8) - q, stats = optimize( - obj, q₀, T; - optimizer = Optimisers.AdaGrad(1e-1), - progress = PROGRESS, - rng = rng, - ) - μ = 
q.location - L = q.scale - - rng_repl = Philox4x(UInt64, seed, 8) - q, stats = optimize( - obj, q₀, T; - optimizer = Optimisers.AdaGrad(1e-1), - progress = PROGRESS, - rng = rng_repl, - ) - μ_repl = q.location - L_repl = q.scale - @test μ == μ_repl - @test L == L_repl - end - end -end - diff --git a/test/exact/normallognormal.jl b/test/exact/normallognormal.jl deleted file mode 100644 index 18e8b4a3..00000000 --- a/test/exact/normallognormal.jl +++ /dev/null @@ -1,66 +0,0 @@ - -struct NormalLogNormal{MX,SX,MY,SY} - μ_x::MX - σ_x::SX - μ_y::MY - Σ_y::SY -end - -function LogDensityProblems.logdensity(model::NormalLogNormal, θ) - @unpack μ_x, σ_x, μ_y, Σ_y = model - logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end]) -end - -function LogDensityProblems.dimension(model::NormalLogNormal) - length(model.μ_y) + 1 -end - -function LogDensityProblems.capabilities(::Type{<:NormalLogNormal}) - LogDensityProblems.LogDensityOrder{0}() -end - -function Bijectors.bijector(model::NormalLogNormal) - @unpack μ_x, σ_x, μ_y, Σ_y = model - Bijectors.Stacked( - Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]), - [1:1, 2:1+length(μ_y)]) -end - -function normallognormal_fullrank(realtype; rng = default_rng()) - n_dims = 5 - - μ_x = randn(rng, realtype) - σ_x = ℯ - μ_y = randn(rng, realtype, n_dims) - L₀_y = randn(rng, realtype, n_dims, n_dims) |> LowerTriangular - ϵ = realtype(n_dims*2) - Σ_y = (L₀_y*L₀_y' + ϵ*I) |> Hermitian - - model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDMat(Σ_y)) - - Σ = Matrix{realtype}(undef, n_dims+1, n_dims+1) - Σ[1,1] = σ_x^2 - Σ[2:end,2:end] = Σ_y - Σ = Σ |> Hermitian - - μ = vcat(μ_x, μ_y) - L = cholesky(Σ).L |> LowerTriangular - - TestModel(model, μ, L, n_dims+1, false) -end - -function normallognormal_meanfield(realtype; rng = default_rng()) - n_dims = 5 - - μ_x = randn(rng, realtype) - σ_x = ℯ - μ_y = randn(rng, realtype, n_dims) - σ_y = log.(exp.(randn(rng, realtype, n_dims)) .+ 1) - - model = NormalLogNormal(μ_x, σ_x, 
μ_y, PDMats.PDiagMat(σ_y.^2)) - - μ = vcat(μ_x, μ_y) - L = vcat(σ_x, σ_y) |> Diagonal - - TestModel(model, μ, L, n_dims+1, true) -end diff --git a/test/runtests.jl b/test/runtests.jl index b571f8b8..ddc1d09c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,11 +8,9 @@ using Distributions using LinearAlgebra using AdvancedVI -const GROUP = get(ENV, "AHMC_TEST_GROUP", "AdvancedHMC") - include("ad.jl") include("distributions.jl") -include("exact.jl") +include("advi_locscale.jl") @main function runtests(patterns...; dry::Bool = false) retest(patterns...; dry = dry, verbose = Inf) From 5658cbf10e3f6e64d7b03380d4c026951cb3f0c2 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Fri, 14 Jul 2023 19:40:59 +0100 Subject: [PATCH 051/144] add missing files --- test/Project.toml | 20 +++++++ test/advi_locscale.jl | 102 +++++++++++++++++++++++++++++++++ test/models/normallognormal.jl | 66 +++++++++++++++++++++ 3 files changed, 188 insertions(+) create mode 100644 test/Project.toml create mode 100644 test/advi_locscale.jl create mode 100644 test/models/normallognormal.jl diff --git a/test/Project.toml b/test/Project.toml new file mode 100644 index 00000000..2f38c88f --- /dev/null +++ b/test/Project.toml @@ -0,0 +1,20 @@ +[deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +Comonicon = "863f3e99-da2a-4334-8734-de3dacbe5542" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" +PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" +Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Random123 = "74087812-796a-5b5d-8853-05524746bad3" +ReTest = "e0db7c4e-2690-44b9-bad6-7687da720f89" +ReverseDiff = 
"37e2e3b7-166d-5795-8a7a-e32c996b4267" +SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl new file mode 100644 index 00000000..2beb0547 --- /dev/null +++ b/test/advi_locscale.jl @@ -0,0 +1,102 @@ + +const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress" ? true : false + +using ReTest +using Bijectors +using LogDensityProblems +using Optimisers +using Distributions +using PDMats +using LinearAlgebra +using SimpleUnPack: @unpack + +struct TestModel{M,L,S} + model::M + μ_true::L + L_true::S + n_dims::Int + is_meanfield::Bool +end + +include("models/normallognormal.jl") + +@testset "exact" begin + @testset "$(modelname) $(objname) $(realtype)" for + realtype ∈ [Float32, Float64], + (modelname, modelconstr) ∈ Dict( + :NormalLogNormalMeanField => normallognormal_meanfield, + :NormalLogNormalFullRank => normallognormal_fullrank, + ), + (objname, objective) ∈ Dict( + :ADVIClosedFormEntropy => (model, b⁻¹, M) -> ADVI(model, b⁻¹, M), + :ADVIStickingTheLanding => (model, b⁻¹, M) -> ADVI(model, b⁻¹, StickingTheLandingEntropy(), M), + :ADVIFullMonteCarlo => (model, b⁻¹, M) -> ADVI(model, b⁻¹, MonteCarloEntropy(), M), + ) + seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797) + rng = Philox4x(UInt64, seed, 8) + + T = 10000 + modelstats = modelconstr(realtype; rng) + @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats + + b = Bijectors.bijector(model) + b⁻¹ = inverse(b) + + μ₀ = zeros(realtype, n_dims) + L₀ = if is_meanfield + ones(realtype, n_dims) |> Diagonal + else + diagm(ones(realtype, n_dims)) |> LowerTriangular + end + q₀ = if is_meanfield + VIMeanFieldGaussian(μ₀, L₀) + else + VIFullRankGaussian(μ₀, L₀) + end + + obj = objective(model, b⁻¹, 10) + + @testset "convergence" begin + Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true) + q, stats = 
optimize( + obj, q₀, T; + optimizer = Optimisers.AdaGrad(1e-1), + progress = PROGRESS, + rng = rng, + ) + + μ = q.location + L = q.scale + Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true) + + @test Δλ ≤ Δλ₀/√T + @test eltype(μ) == eltype(μ_true) + @test eltype(L) == eltype(L_true) + end + + @testset "determinism" begin + rng = Philox4x(UInt64, seed, 8) + q, stats = optimize( + obj, q₀, T; + optimizer = Optimisers.AdaGrad(1e-1), + progress = PROGRESS, + rng = rng, + ) + μ = q.location + L = q.scale + + rng_repl = Philox4x(UInt64, seed, 8) + q, stats = optimize( + obj, q₀, T; + optimizer = Optimisers.AdaGrad(1e-1), + progress = PROGRESS, + rng = rng_repl, + ) + μ_repl = q.location + L_repl = q.scale + @test μ == μ_repl + @test L == L_repl + end + end +end + diff --git a/test/models/normallognormal.jl b/test/models/normallognormal.jl new file mode 100644 index 00000000..18e8b4a3 --- /dev/null +++ b/test/models/normallognormal.jl @@ -0,0 +1,66 @@ + +struct NormalLogNormal{MX,SX,MY,SY} + μ_x::MX + σ_x::SX + μ_y::MY + Σ_y::SY +end + +function LogDensityProblems.logdensity(model::NormalLogNormal, θ) + @unpack μ_x, σ_x, μ_y, Σ_y = model + logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end]) +end + +function LogDensityProblems.dimension(model::NormalLogNormal) + length(model.μ_y) + 1 +end + +function LogDensityProblems.capabilities(::Type{<:NormalLogNormal}) + LogDensityProblems.LogDensityOrder{0}() +end + +function Bijectors.bijector(model::NormalLogNormal) + @unpack μ_x, σ_x, μ_y, Σ_y = model + Bijectors.Stacked( + Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]), + [1:1, 2:1+length(μ_y)]) +end + +function normallognormal_fullrank(realtype; rng = default_rng()) + n_dims = 5 + + μ_x = randn(rng, realtype) + σ_x = ℯ + μ_y = randn(rng, realtype, n_dims) + L₀_y = randn(rng, realtype, n_dims, n_dims) |> LowerTriangular + ϵ = realtype(n_dims*2) + Σ_y = (L₀_y*L₀_y' + ϵ*I) |> Hermitian + + model = NormalLogNormal(μ_x, σ_x, μ_y, 
PDMats.PDMat(Σ_y)) + + Σ = Matrix{realtype}(undef, n_dims+1, n_dims+1) + Σ[1,1] = σ_x^2 + Σ[2:end,2:end] = Σ_y + Σ = Σ |> Hermitian + + μ = vcat(μ_x, μ_y) + L = cholesky(Σ).L |> LowerTriangular + + TestModel(model, μ, L, n_dims+1, false) +end + +function normallognormal_meanfield(realtype; rng = default_rng()) + n_dims = 5 + + μ_x = randn(rng, realtype) + σ_x = ℯ + μ_y = randn(rng, realtype, n_dims) + σ_y = log.(exp.(randn(rng, realtype, n_dims)) .+ 1) + + model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2)) + + μ = vcat(μ_x, μ_y) + L = vcat(σ_x, σ_y) |> Diagonal + + TestModel(model, μ, L, n_dims+1, true) +end From c712a9762afdbc60468953bfeab1ad076a6cc2f9 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Fri, 14 Jul 2023 19:51:14 +0100 Subject: [PATCH 052/144] fix add missing file, rename adbackend argument --- src/grad.jl | 30 ++++++++++++++++++++++++++++++ src/optimize.jl | 2 +- 2 files changed, 31 insertions(+), 1 deletion(-) create mode 100644 src/grad.jl diff --git a/src/grad.jl b/src/grad.jl new file mode 100644 index 00000000..e68e1623 --- /dev/null +++ b/src/grad.jl @@ -0,0 +1,30 @@ + +# default implementations +function grad!( + f::Function, + adtype::AutoForwardDiff{chunksize}, + λ::AbstractVector{<:Real}, + out::DiffResults.MutableDiffResult +) where {chunksize} + # Set chunk size and do ForwardMode. 
+ config = if isnothing(chunksize) + ForwardDiff.GradientConfig(f, λ) + else + ForwardDiff.GradientConfig(f, λ, ForwardDiff.Chunk(length(λ), chunksize)) + end + ForwardDiff.gradient!(out, f, λ, config) +end + +function grad!( + f::Function, + ::AutoTracker, + λ::AbstractVector{<:Real}, + out::DiffResults.MutableDiffResult +) + λ_tracked = Tracker.param(λ) + y = f(λ_tracked) + Tracker.back!(y, 1.0) + + DiffResults.value!(out, Tracker.data(y)) + DiffResults.gradient!(out, Tracker.grad(λ_tracked)) +end diff --git a/src/optimize.jl b/src/optimize.jl index dcd1c439..16995925 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -19,7 +19,7 @@ function optimize( progress ::Bool = true, callback! = nothing, terminate = (args...) -> false, - adback::AbstractADType = AutoForwardDiff(), + adbackend::AbstractADType = AutoForwardDiff(), ) opt_state = Optimisers.init(optimizer, λ) est_state = init(objective) From bee839d91399ce9cc2d776f907dd9197e14aa241 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Fri, 14 Jul 2023 20:03:16 +0100 Subject: [PATCH 053/144] fix errors --- src/AdvancedVI.jl | 2 ++ src/objectives/elbo/advi.jl | 4 ++-- src/optimize.jl | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 86c9fc44..4010b1fe 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -63,6 +63,8 @@ abstract type AbstractVariationalObjective end function init end function estimate_gradient end +init(::Nothing) = nothing + # ADVI-specific interfaces abstract type AbstractEntropyEstimator end abstract type AbstractControlVariate end diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl index b9b1185f..1fb6b0c6 100644 --- a/src/objectives/elbo/advi.jl +++ b/src/objectives/elbo/advi.jl @@ -59,7 +59,7 @@ end function estimate_gradient( rng::AbstractRNG, - adback::AbstractADType, + adbackend::AbstractADType, advi::ADVI, est_state, λ::Vector{<:Real}, @@ -69,7 +69,7 @@ function estimate_gradient( # Gradient-stopping for 
computing the sticking-the-landing control variate q_η_stop = skip_entropy_gradient(advi.entropy_estimator) ? restructure(λ) : nothing - grad!(adback, λ, out) do λ′ + grad!(adbackend, λ, out) do λ′ q_η = restructure(λ′) q_η_entropy = skip_entropy_gradient(advi.entropy_estimator) ? q_η_stop : q_η -advi(q_η; rng, q_η_entropy) diff --git a/src/optimize.jl b/src/optimize.jl index 16995925..8b36df04 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -35,7 +35,7 @@ function optimize( stat = (iteration=t,) grad_buf, est_state, stat′ = estimate_gradient( - rng, adback, objective, est_state, λ, restructure, grad_buf) + rng, adbackend, objective, est_state, λ, restructure, grad_buf) g = DiffResults.gradient(grad_buf) stat = merge(stat, stat′) From 913911ec74f835d566e2f19b0df16358a3fd055b Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Fri, 14 Jul 2023 20:03:23 +0100 Subject: [PATCH 054/144] rename test suite --- test/advi_locscale.jl | 149 +++++++++++++++++++++++------------------- 1 file changed, 80 insertions(+), 69 deletions(-) diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl index 2beb0547..342b9db1 100644 --- a/test/advi_locscale.jl +++ b/test/advi_locscale.jl @@ -20,83 +20,94 @@ end include("models/normallognormal.jl") -@testset "exact" begin - @testset "$(modelname) $(objname) $(realtype)" for - realtype ∈ [Float32, Float64], - (modelname, modelconstr) ∈ Dict( - :NormalLogNormalMeanField => normallognormal_meanfield, - :NormalLogNormalFullRank => normallognormal_fullrank, - ), - (objname, objective) ∈ Dict( - :ADVIClosedFormEntropy => (model, b⁻¹, M) -> ADVI(model, b⁻¹, M), - :ADVIStickingTheLanding => (model, b⁻¹, M) -> ADVI(model, b⁻¹, StickingTheLandingEntropy(), M), - :ADVIFullMonteCarlo => (model, b⁻¹, M) -> ADVI(model, b⁻¹, MonteCarloEntropy(), M), - ) - seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797) - rng = Philox4x(UInt64, seed, 8) - - T = 10000 - modelstats = modelconstr(realtype; rng) - @unpack model, μ_true, L_true, n_dims, is_meanfield = 
modelstats +@testset "advi" begin + @testset "locscale" begin + @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for + realtype ∈ [Float32, Float64], + (modelname, modelconstr) ∈ Dict( + :NormalLogNormalMeanField => normallognormal_meanfield, + :NormalLogNormalFullRank => normallognormal_fullrank, + ), + (objname, objective) ∈ Dict( + :ADVIClosedFormEntropy => (model, b⁻¹, M) -> ADVI(model, b⁻¹, M), + :ADVIStickingTheLanding => (model, b⁻¹, M) -> ADVI(model, b⁻¹, StickingTheLandingEntropy(), M), + :ADVIFullMonteCarlo => (model, b⁻¹, M) -> ADVI(model, b⁻¹, MonteCarloEntropy(), M), + ), + (adbackname, adbackend) ∈ Dict( + :ForwarDiff => AutoForwardDiff(), + :ReverseDiff => AutoReverseDiff(), + :Zygote => AutoZygote(), + :Enzyme => AutoEnzyme(), + ) - b = Bijectors.bijector(model) - b⁻¹ = inverse(b) + seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797) + rng = Philox4x(UInt64, seed, 8) - μ₀ = zeros(realtype, n_dims) - L₀ = if is_meanfield - ones(realtype, n_dims) |> Diagonal - else - diagm(ones(realtype, n_dims)) |> LowerTriangular - end - q₀ = if is_meanfield - VIMeanFieldGaussian(μ₀, L₀) - else - VIFullRankGaussian(μ₀, L₀) - end + T = 10000 + modelstats = modelconstr(realtype; rng) + @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats - obj = objective(model, b⁻¹, 10) + b = Bijectors.bijector(model) + b⁻¹ = inverse(b) - @testset "convergence" begin - Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true) - q, stats = optimize( - obj, q₀, T; - optimizer = Optimisers.AdaGrad(1e-1), - progress = PROGRESS, - rng = rng, - ) + μ₀ = zeros(realtype, n_dims) + L₀ = if is_meanfield + ones(realtype, n_dims) |> Diagonal + else + diagm(ones(realtype, n_dims)) |> LowerTriangular + end + q₀ = if is_meanfield + VIMeanFieldGaussian(μ₀, L₀) + else + VIFullRankGaussian(μ₀, L₀) + end - μ = q.location - L = q.scale - Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true) + obj = objective(model, b⁻¹, 10) - @test Δλ ≤ Δλ₀/√T - @test eltype(μ) == eltype(μ_true) - @test 
eltype(L) == eltype(L_true) - end + @testset "convergence" begin + Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true) + q, stats = optimize( + obj, q₀, T; + optimizer = Optimisers.AdaGrad(1e-1), + progress = PROGRESS, + rng = rng, + adbackend = adbackend, + ) - @testset "determinism" begin - rng = Philox4x(UInt64, seed, 8) - q, stats = optimize( - obj, q₀, T; - optimizer = Optimisers.AdaGrad(1e-1), - progress = PROGRESS, - rng = rng, - ) - μ = q.location - L = q.scale + μ = q.location + L = q.scale + Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true) - rng_repl = Philox4x(UInt64, seed, 8) - q, stats = optimize( - obj, q₀, T; - optimizer = Optimisers.AdaGrad(1e-1), - progress = PROGRESS, - rng = rng_repl, - ) - μ_repl = q.location - L_repl = q.scale - @test μ == μ_repl - @test L == L_repl + @test Δλ ≤ Δλ₀/√T + @test eltype(μ) == eltype(μ_true) + @test eltype(L) == eltype(L_true) + end + + @testset "determinism" begin + rng = Philox4x(UInt64, seed, 8) + q, stats = optimize( + obj, q₀, T; + optimizer = Optimisers.AdaGrad(1e-1), + progress = PROGRESS, + rng = rng, + adbackend = adbackend, + ) + μ = q.location + L = q.scale + + rng_repl = Philox4x(UInt64, seed, 8) + q, stats = optimize( + obj, q₀, T; + optimizer = Optimisers.AdaGrad(1e-1), + progress = PROGRESS, + rng = rng_repl, + adbackend = adbackend, + ) + μ_repl = q.location + L_repl = q.scale + @test μ == μ_repl + @test L == L_repl + end end end end - From d50cabb0f0b7b7fac8bfd79c43ef38196b2df8c9 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sat, 15 Jul 2023 01:59:43 +0100 Subject: [PATCH 055/144] refactor renamed arguments for ADVI to be shorter --- Project.toml | 3 +- src/AdvancedVI.jl | 7 ++-- src/objectives/elbo/advi.jl | 59 +++++++++++++++++----------------- src/objectives/elbo/entropy.jl | 42 ++++++++++++++---------- test/ad.jl | 10 +++--- test/advi_locscale.jl | 18 +++++------ 6 files changed, 73 insertions(+), 66 deletions(-) diff --git a/Project.toml b/Project.toml index 2fcc845e..cf698f7a 100644 --- 
a/Project.toml +++ b/Project.toml @@ -7,7 +7,6 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" -DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" @@ -25,9 +24,9 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" [compat] +ADTypes = "0.1" Bijectors = "0.11, 0.12, 0.13" Distributions = "0.21, 0.22, 0.23, 0.24, 0.25" -DistributionsAD = "0.2, 0.3, 0.4, 0.5, 0.6" DocStringExtensions = "0.8, 0.9" ForwardDiff = "0.10.3" ProgressMeter = "1.0.0" diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 4010b1fe..e3dd85a8 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -5,7 +5,10 @@ using SimpleUnPack: @unpack, @pack! using Accessors import Random: AbstractRNG, default_rng -import Distributions: logpdf, _logpdf, rand, _rand!, _rand! +using Distributions +import Distributions: + logpdf, _logpdf, rand, _rand!, _rand!, + ContinuousMultivariateDistribution using Functors using Optimisers @@ -24,8 +27,6 @@ using ForwardDiff, Tracker using FillArrays using PDMats -using Distributions, DistributionsAD -using Distributions: ContinuousMultivariateDistribution using Bijectors using StatsBase diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl index 1fb6b0c6..e965ea73 100644 --- a/src/objectives/elbo/advi.jl +++ b/src/objectives/elbo/advi.jl @@ -1,14 +1,28 @@ +""" + ADVI + +Automatic differentiation variational inference (ADVI; Kucukelbir *et al.* 2017) objective. + +# Requirements +- ``q_{\\lambda}`` implements `rand`. +- ``\\pi`` must be differentiable + +Depending on the options, additional requirements on ``q_{\\lambda}`` may apply. 
+""" struct ADVI{Tlogπ, B, - EntropyEst <: AbstractEntropyEstimator, - ControlVar <: Union{<: AbstractControlVariate, Nothing}} <: AbstractVariationalObjective + EntropyEst <: AbstractEntropyEstimator, + ControlVar <: Union{<: AbstractControlVariate, Nothing}} <: AbstractVariationalObjective ℓπ::Tlogπ - b⁻¹::B - entropy_estimator::EntropyEst - control_variate::ControlVar + b::B + entropy::EntropyEst + cv::ControlVar n_samples::Int - function ADVI(prob, b⁻¹, entropy_estimator, control_variate, n_samples) + function ADVI(prob, n_samples::Int; + entropy::AbstractEntropyEstimator = ClosedFormEntropy(), + cv::Union{<:AbstractControlVariate, Nothing} = nothing, + b = Bijectors.identity) cap = LogDensityProblems.capabilities(prob) if cap === nothing throw( @@ -18,31 +32,16 @@ struct ADVI{Tlogπ, B, ) end ℓπ = Base.Fix1(LogDensityProblems.logdensity, prob) - new{typeof(ℓπ), typeof(b⁻¹), typeof(entropy_estimator), typeof(control_variate)}( - ℓπ, b⁻¹, entropy_estimator, control_variate, n_samples - ) + new{typeof(ℓπ), typeof(b), typeof(entropy), typeof(cv)}(ℓπ, b, entropy, cv, n_samples) end end Base.show(io::IO, advi::ADVI) = - print(io, - "ADVI(entropy_estimator=$(advi.entropy_estimator), " * - "control_variate=$(advi.control_variate), " * - "n_samples=$(advi.n_samples))") - -skip_entropy_gradient(advi::ADVI) = skip_entropy_gradient(advi.entropy_estimator) + print(io, "ADVI(entropy=$(advi.entropy), cv=$(advi.cv), n_samples=$(advi.n_samples))") -init(advi::ADVI) = init(advi.control_variate) +skip_entropy_gradient(advi::ADVI) = skip_entropy_gradient(advi.entropy) -function ADVI(ℓπ, b⁻¹, - entropy_estimator::AbstractEntropyEstimator, - n_samples::Int) - ADVI(ℓπ, b⁻¹, entropy_estimator, nothing, n_samples) -end - -function ADVI(ℓπ, b⁻¹, n_samples::Int) - ADVI(ℓπ, b⁻¹, ClosedFormEntropy(), nothing, n_samples) -end +init(advi::ADVI) = init(advi.cv) function (advi::ADVI)(q_η::ContinuousMultivariateDistribution; rng ::AbstractRNG = default_rng(), @@ -50,10 +49,10 @@ function 
(advi::ADVI)(q_η::ContinuousMultivariateDistribution; ηs ::AbstractMatrix = rand(rng, q_η, n_samples), q_η_entropy::ContinuousMultivariateDistribution = q_η) 𝔼ℓ = mapreduce(+, eachcol(ηs)) do ηᵢ - zᵢ, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(advi.b⁻¹, ηᵢ) + zᵢ, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(advi.b, ηᵢ) (advi.ℓπ(zᵢ) + logdetjacᵢ) / n_samples end - ℍ = advi.entropy_estimator(q_η_entropy, ηs) + ℍ = advi.entropy(q_η_entropy, ηs) 𝔼ℓ + ℍ end @@ -67,17 +66,17 @@ function estimate_gradient( out::DiffResults.MutableDiffResult) # Gradient-stopping for computing the sticking-the-landing control variate - q_η_stop = skip_entropy_gradient(advi.entropy_estimator) ? restructure(λ) : nothing + q_η_stop = skip_entropy_gradient(advi.entropy) ? restructure(λ) : nothing grad!(adbackend, λ, out) do λ′ q_η = restructure(λ′) - q_η_entropy = skip_entropy_gradient(advi.entropy_estimator) ? q_η_stop : q_η + q_η_entropy = skip_entropy_gradient(advi.entropy) ? q_η_stop : q_η -advi(q_η; rng, q_η_entropy) end nelbo = DiffResults.value(out) stat = (elbo=-nelbo,) - est_state, stat′ = update(advi.control_variate, est_state) + est_state, stat′ = update(advi.cv, est_state) stat = !isnothing(stat′) ? 
merge(stat′, stat) : stat out, est_state, stat diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl index ddeb64a9..994bdd4f 100644 --- a/src/objectives/elbo/entropy.jl +++ b/src/objectives/elbo/entropy.jl @@ -14,27 +14,35 @@ MonteCarloEntropy() = MonteCarloEntropy{false}() Base.show(io::IO, entropy::MonteCarloEntropy{false}) = print(io, "MonteCarloEntropy()") """ - Sticking the Landing Control Variate + StickingTheLandingEntropy() - # Explanation +# Explanation - This eatimator forms a control variate of the form of +The STL estimator forms a control variate of the form of - c(z) = 𝔼-logq(z) + logq(z) = ℍ[q] - logq(z) +```math +\\mathrm{CV}_{\\mathrm{STL}}\\left(z\\right) = + \\mathbb{E}\\left[ -\\log q\\left(z\\right) \\right] + + \\log q\\left(z\\right) = \\mathbb{H}\\left(q_{\\lambda}\\right) + \\log q_{\\lambda}\\left(z\\right), +``` +where, for the score term, the gradient is stopped from propagating. - Adding this to the closed-form entropy ELBO estimator yields: - - ELBO - c(z) = 𝔼logπ(z) + ℍ[q] - c(z) = 𝔼logπ(z) - logq(z), - - which has the same expectation, but lower variance when π ≈ q, - and higher variance when π ≉ q. - - # Reference - - Roeder, Geoffrey, Yuhuai Wu, and David K. Duvenaud. - "Sticking the landing: Simple, lower-variance gradient estimators for - variational inference." - Advances in Neural Information Processing Systems 30 (2017). 
+Adding this to the closed-form entropy ELBO estimator yields the STL estimator: +```math +\\begin{aligned} + \\widehat{\\mathrm{ELBO}}_{\\mathrm{STL}}\\left(\\lambda\\right) + &\\triangleq \\mathbb{E}\\left[ \\log \\pi \\left(z\\right) \\right] - \\log q_{\\lambda} \\left(z\\right) \\\\ + &= \\mathbb{E}\\left[ \\log \\pi\\left(z\\right) \\right] + + \\mathbb{H}\\left(q_{\\lambda}\\right) - \\mathrm{CV}_{\\mathrm{STL}}\\left(z\\right) \\\\ + &= \\widehat{\\mathrm{ELBO}}\\left(\\lambda\\right) + - \\mathrm{CV}_{\\mathrm{STL}}\\left(z\\right), +\\end{aligned} +``` +which has the same expectation, but lower variance when ``\\pi \\approx q_{\\lambda}``, +and higher variance when ``\\pi \\not\\approx q_{\\lambda}``. + +# Reference +1. Roeder, G., Wu, Y., & Duvenaud, D. K. (2017). Sticking the landing: Simple, lower-variance gradient estimators for variational inference. Advances in Neural Information Processing Systems, 30. """ StickingTheLandingEntropy() = MonteCarloEntropy{true}() diff --git a/test/ad.jl b/test/ad.jl index 6b587598..1efa536b 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -5,11 +5,11 @@ using ADTypes @testset "ad" begin @testset "$(adname)" for (adname, adsymbol) ∈ Dict( - :ForwardDiffAuto => AutoForwardDiff(), - :ForwardDiff => AutoForwardDiff(10), - :ReverseDiff => AutoReverseDiff(), - :Zygote => AutoZygote(), - :Tracker => AutoTracker(), + :ForwardDiff => AutoForwardDiff(), + :ReverseDiff => AutoReverseDiff(), + :Zygote => AutoZygote(), + :Tracker => AutoTracker(), + :Enzyme => AutoEnzyme(), ) D = 10 A = randn(D, D) diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl index 342b9db1..dadbaf25 100644 --- a/test/advi_locscale.jl +++ b/test/advi_locscale.jl @@ -29,15 +29,15 @@ include("models/normallognormal.jl") :NormalLogNormalFullRank => normallognormal_fullrank, ), (objname, objective) ∈ Dict( - :ADVIClosedFormEntropy => (model, b⁻¹, M) -> ADVI(model, b⁻¹, M), - :ADVIStickingTheLanding => (model, b⁻¹, M) -> ADVI(model, b⁻¹, 
StickingTheLandingEntropy(), M), - :ADVIFullMonteCarlo => (model, b⁻¹, M) -> ADVI(model, b⁻¹, MonteCarloEntropy(), M), + :ADVIClosedFormEntropy => (model, b, M) -> ADVI(model, M; b), + :ADVIStickingTheLanding => (model, b, M) -> ADVI(model, M; b, H = StickingTheLandingEntropy()), + :ADVIFullMonteCarlo => (model, b, M) -> ADVI(model, M; b, H = MonteCarloEntropy()), ), (adbackname, adbackend) ∈ Dict( :ForwarDiff => AutoForwardDiff(), - :ReverseDiff => AutoReverseDiff(), - :Zygote => AutoZygote(), - :Enzyme => AutoEnzyme(), + # :ReverseDiff => AutoReverseDiff(), + # :Zygote => AutoZygote(), + # :Enzyme => AutoEnzyme(), ) seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797) @@ -68,7 +68,7 @@ include("models/normallognormal.jl") Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true) q, stats = optimize( obj, q₀, T; - optimizer = Optimisers.AdaGrad(1e-1), + optimizer = Optimisers.Adam(1e-3), progress = PROGRESS, rng = rng, adbackend = adbackend, @@ -87,7 +87,7 @@ include("models/normallognormal.jl") rng = Philox4x(UInt64, seed, 8) q, stats = optimize( obj, q₀, T; - optimizer = Optimisers.AdaGrad(1e-1), + optimizer = Optimisers.Adam(1e-3), progress = PROGRESS, rng = rng, adbackend = adbackend, @@ -98,7 +98,7 @@ include("models/normallognormal.jl") rng_repl = Philox4x(UInt64, seed, 8) q, stats = optimize( obj, q₀, T; - optimizer = Optimisers.AdaGrad(1e-1), + optimizer = Optimisers.Adam(1e-3), progress = PROGRESS, rng = rng_repl, adbackend = adbackend, From b134f7099062b2c7a6d7b3ec9e30867703c609da Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sat, 15 Jul 2023 02:07:08 +0100 Subject: [PATCH 056/144] fix compile error in advi test --- test/advi_locscale.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl index dadbaf25..40e5dace 100644 --- a/test/advi_locscale.jl +++ b/test/advi_locscale.jl @@ -30,8 +30,8 @@ include("models/normallognormal.jl") ), (objname, objective) ∈ Dict( :ADVIClosedFormEntropy => (model, b, M) 
-> ADVI(model, M; b), - :ADVIStickingTheLanding => (model, b, M) -> ADVI(model, M; b, H = StickingTheLandingEntropy()), - :ADVIFullMonteCarlo => (model, b, M) -> ADVI(model, M; b, H = MonteCarloEntropy()), + :ADVIStickingTheLanding => (model, b, M) -> ADVI(model, M; b, entropy = StickingTheLandingEntropy()), + :ADVIFullMonteCarlo => (model, b, M) -> ADVI(model, M; b, entropy = MonteCarloEntropy()), ), (adbackname, adbackend) ∈ Dict( :ForwarDiff => AutoForwardDiff(), From a6ba379b9a97e509076ce0c7e2c2ebd4b6caa737 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sat, 15 Jul 2023 22:33:52 +0100 Subject: [PATCH 057/144] add initial doc --- docs/make.jl | 17 +++++++++++ docs/src/advi.md | 67 ++++++++++++++++++++++++++++++++++++++++++++ docs/src/families.md | 58 ++++++++++++++++++++++++++++++++++++++ docs/src/index.md | 14 +++++++++ 4 files changed, 156 insertions(+) create mode 100644 docs/make.jl create mode 100644 docs/src/advi.md create mode 100644 docs/src/families.md create mode 100644 docs/src/index.md diff --git a/docs/make.jl b/docs/make.jl new file mode 100644 index 00000000..d2a01d1b --- /dev/null +++ b/docs/make.jl @@ -0,0 +1,17 @@ +#using AdvancedVI +using Documenter + +DocMeta.setdocmeta!( + AdvancedVI, :DocTestSetup, :(using AdvancedVI); recursive=true +) + +makedocs(; + sitename = "AdvancedVI.jl", + modules = [AdvancedVI], + format = Documenter.HTML(; prettyurls = get(ENV, "CI", nothing) == "true"), + pages = ["index.md", + "families.md", + "advi.md"], +) + +deploydocs(; repo="github.com/TuringLang/AdvancedVI.jl", devbranch="main") diff --git a/docs/src/advi.md b/docs/src/advi.md new file mode 100644 index 00000000..4f4a2eca --- /dev/null +++ b/docs/src/advi.md @@ -0,0 +1,67 @@ + +# [Automatic Differentiation Variational Inference](@id advi) +The automatic differentiation variational inference (ADVI; Kucukelbir *et al.* 2017) objective is a method for estimating the evidence lower bound between a target posterior distribution ``\pi`` and a variational 
approximation ``q_{\phi,\lambda}``. +By maximizing ADVI objective, it is equivalent to solving the problem + +```math + \mathrm{minimize}_{\lambda \in \Lambda}\quad \mathrm{KL}\left(q_{\phi,\lambda}, \pi\right). +``` + +The key aspects of the ADVI objective are the followings: +1. The use of the reparameterization gradient estimator +2. Automatically match the support of the target posterior through "bijectors." + +Thanks to Item 2, the user is free to choose any unconstrained variational family, for which +bijectors will automatically match the potentially constrained support of the target. + +In particular, ADVI implicitly forms a variational approximation ``q_{\phi,\lambda}`` +from a reparameterizable distribution ``q_{\lambda}`` and a bijector ``\phi`` such that +```math +z &\sim q_{\phi,\lambda} \qquad\Leftrightarrow\qquad +z &\stackrel{d}{=} \phi^{-1}\left(\eta\right);\quad \eta \sim q_{\lambda} +``` +ADVI provides a principled way to compute the evidence lower bound for ``q_{\phi,\lambda}``. + +That is, + +```math +\begin{aligned} +\mathrm{ADVI}\left(\lambda\right) +&\triangleq +\mathbb{E}_{\eta \sim q_{\lambda}}\left[ + \log \pi\left( \phi^{-1}\left( \eta \right) \right) +\right] ++ \mathbb{H}\left(q_{\lambda}\right) ++ \log \lvert J_{\phi^{-1}}\left(\eta\right) \rvert \\ +&= +\mathbb{E}_{\eta \sim q_{\lambda}}\left[ + \log \pi\left( \phi^{-1}\left( \eta \right) \right) +\right] ++ +\mathbb{E}_{\eta \sim q_{\lambda}}\left[ + - \log q_{\lambda}\left( \eta \right) \lvert J_{\phi}\left(\eta\right) \rvert +\right] \\ +&= +\mathbb{E}_{z \sim q_{\phi,\lambda}}\left[ \log \pi\left(z\right) \right] ++ +\mathbb{H}\left(q_{\phi,\lambda}\right) +\end{aligned} +``` + +The idea of using the reparameterization gradient estimator for variational inference was first +coined by Titsias and Lázaro-Gredilla (2014). +Bijectors were generalized by Dillon *et al.* (2017) and later implemented in Julia by +Fjelde *et al.* (2017). + + +```@docs +ADVI +``` + +# References +1. 
Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. Journal of machine learning research. +2. Titsias, M., & Lázaro-Gredilla, M. (2014, June). Doubly stochastic variational Bayes for non-conjugate inference. In International conference on machine learning (pp. 1971-1979). PMLR. +3. Dillon, J. V., Langmore, I., Tran, D., Brevdo, E., Vasudevan, S., Moore, D., ... & Saurous, R. A. (2017). Tensorflow distributions. arXiv preprint arXiv:1711.10604. +4. Fjelde, T. E., Xu, K., Tarek, M., Yalburgi, S., & Ge, H. (2020, February). Bijectors. jl: Flexible transformations for probability distributions. In Symposium on Advances in Approximate Bayesian Inference (pp. 1-17). PMLR. + + diff --git a/docs/src/families.md b/docs/src/families.md new file mode 100644 index 00000000..f203cf18 --- /dev/null +++ b/docs/src/families.md @@ -0,0 +1,58 @@ + +# [Variational Families](@id families) + +## Location-Scale Variational Family + +### Description +The [location-scale](https://en.wikipedia.org/wiki/Location%E2%80%93scale_family) variational family is a family of probability distributions, where their sampling process can be represented as +```math +z = C u + m, +``` +where ``C`` is the *scale* and ``m`` is the location variational parameter. 
+This family encompases many + + +### Constructors + +```@docs +VILocationScale +``` + +```@docs +VIFullRankGaussian +VIMeanFieldGaussian +``` + +### Examples + +A full-rank variational family can be formed by choosing +```@repl locscale +using AdvancedVI, LinearAlgebra +μ = zeros(2); +L = diagm(ones(2)) |> LowerTriangular; +``` + +A mean-field variational family can be formed by choosing +```@repl locscale +μ = zeros(2); +L = ones(2) |> Diagonal; +``` + +Gaussian variational family: +```@repl locscale +q = VIFullRankGaussian(μ, L) +q = VIMeanFieldGaussian(μ, L) +``` + +Sudent-T Variational Family: + +```@repl locscale +ν = 3 +q = VILocationScale(μ, L, StudentT(ν)) +``` + +Multivariate Laplace family: +```@repl locscale +q = VILocationScale(μ, L, Laplace()) +``` + diff --git a/docs/src/index.md b/docs/src/index.md new file mode 100644 index 00000000..be326921 --- /dev/null +++ b/docs/src/index.md @@ -0,0 +1,14 @@ +```@meta +CurrentModule = AdvancedVI +``` + +# AdvancedVI + +Documentation for [AdvancedVI](https://github.com/TuringLang/AdvancedVI.jl). + +```@index +``` + +```@autodocs +Modules = [AdvancedVI] +``` From 619b1c05eaf669491f82406becb9a31dba1871cc Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sat, 15 Jul 2023 22:34:32 +0100 Subject: [PATCH 058/144] remove unused epsilon argument in location scale --- src/distributions/location_scale.jl | 40 +++++++++++++++++------------ 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl index dc9c1b27..5eb371ad 100644 --- a/src/distributions/location_scale.jl +++ b/src/distributions/location_scale.jl @@ -1,31 +1,31 @@ """ + VILocationScale{L,R,D}(location::L, scale::S, dist::D) <: ContinuousMultivariateDistribution The [location scale] variational family broadly represents various variational families using `location` and `scale` variational parameters. 
-Multivariate Student-t variational family with ``\\nu``-degrees of freedom can -be constructed as: +It generally represents any distribution for which the sampling path can be +represented as the following: ```julia -q₀ = VILocationScale(μ, L, StudentT(ν), eps(Float32)) + d = length(location) + u = rand(dist, d) + z = scale*u + location ``` - """ -struct VILocationScale{L, S, D, R} <: ContinuousMultivariateDistribution +struct VILocationScale{L, S, D} <: ContinuousMultivariateDistribution location::L scale ::S dist ::D - epsilon ::R function VILocationScale(μ::AbstractVector{<:Real}, L::Union{<:AbstractTriangular{<:Real}, <:Diagonal{<:Real}}, - q_base::ContinuousUnivariateDistribution, - epsilon::Real) + q_base::ContinuousUnivariateDistribution) # Restricting all the arguments to have the same types creates problems # with dual-variable-based AD frameworks. @assert (length(μ) == size(L,1)) && (length(μ) == size(L,2)) - new{typeof(μ), typeof(L), typeof(q_base), typeof(epsilon)}(μ, L, q_base, epsilon) + new{typeof(μ), typeof(L), typeof(q_base)}(μ, L, q_base) end end @@ -76,16 +76,22 @@ function _rand!(rng::AbstractRNG, q::VILocationScale, x::AbstractMatrix{<:Real}) return x += location end -function VIFullRankGaussian(μ::AbstractVector{T}, - L::AbstractTriangular{T}, - epsilon::Real = eps(T)) where {T <: Real} +""" + VIFullRankGaussian(μ::AbstractVector{T}, L::AbstractTriangular{T}) + +This constructs a multivariate Gaussian distribution with a full rank covariance matrix. +""" +function VIFullRankGaussian(μ::AbstractVector{T}, L::AbstractTriangular{T}) where {T <: Real} q_base = Normal{T}(zero(T), one(T)) - VILocationScale(μ, L, q_base, epsilon) + VILocationScale(μ, L, q_base) end -function VIMeanFieldGaussian(μ::AbstractVector{T}, - L::Diagonal{T}, - epsilon::Real = eps(T)) where {T <: Real} +""" + VIFullRankGaussian(μ::AbstractVector{T}, L::AbstractTriangular{T}) + +This constructs a multivariate Gaussian distribution with a diagonal covariance matrix. 
+""" +function VIMeanFieldGaussian(μ::AbstractVector{T}, L::Diagonal{T}) where {T <: Real} q_base = Normal{T}(zero(T), one(T)) - VILocationScale(μ, L, q_base, epsilon) + VILocationScale(μ, L, q_base) end From f1c02f02909ff15ac2ddc6276af8589c97cfedf8 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sat, 15 Jul 2023 22:39:16 +0100 Subject: [PATCH 059/144] add project file for documenter --- docs/Project.toml | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100644 docs/Project.toml diff --git a/docs/Project.toml b/docs/Project.toml new file mode 100644 index 00000000..fc885857 --- /dev/null +++ b/docs/Project.toml @@ -0,0 +1,7 @@ +[deps] +AdvancedVI = "b5ca4192-6429-45e5-a2d9-87aec30a685c" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" + +[compat] +Documenter = "0.26" \ No newline at end of file From b0f259a4c32ad293cf0edd236b42b132d7e959b5 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sun, 16 Jul 2023 02:55:03 +0100 Subject: [PATCH 060/144] refactor STL gradient calculation to use multiple dispatch --- src/AdvancedVI.jl | 6 +- src/distributions/location_scale.jl | 16 ++--- src/objectives/elbo/advi.jl | 97 +++++++++++++++++++++++------ src/objectives/elbo/entropy.jl | 11 +--- test/advi_locscale.jl | 6 +- test/models/normal.jl | 51 +++++++++++++++ test/models/normallognormal.jl | 4 +- test/models/utils.jl | 8 +++ 8 files changed, 160 insertions(+), 39 deletions(-) create mode 100644 test/models/normal.jl create mode 100644 test/models/utils.jl diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index e3dd85a8..9f93885c 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -73,8 +73,9 @@ abstract type AbstractControlVariate end function update end update(::Nothing, ::Nothing) = (nothing, nothing) -include("objectives/elbo/advi.jl") +# entropy.jl must preceed advi.jl include("objectives/elbo/entropy.jl") +include("objectives/elbo/advi.jl") export ELBO, @@ -82,13 +83,14 @@ export ADVIEnergy, 
ClosedFormEntropy, StickingTheLandingEntropy, - MonteCarloEntropy + FullMonteCarloEntropy # Variational Families include("distributions/location_scale.jl") export + VILocationScale, VIFullRankGaussian, VIMeanFieldGaussian diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl index 5eb371ad..e901e8de 100644 --- a/src/distributions/location_scale.jl +++ b/src/distributions/location_scale.jl @@ -1,8 +1,8 @@ """ - VILocationScale{L,R,D}(location::L, scale::S, dist::D) <: ContinuousMultivariateDistribution + VILocationScale(location, scale, dist) <: ContinuousMultivariateDistribution -The [location scale] variational family broadly represents various variational +The location scale variational family broadly represents various variational families using `location` and `scale` variational parameters. It generally represents any distribution for which the sampling path can be @@ -18,14 +18,14 @@ struct VILocationScale{L, S, D} <: ContinuousMultivariateDistribution scale ::S dist ::D - function VILocationScale(μ::AbstractVector{<:Real}, - L::Union{<:AbstractTriangular{<:Real}, + function VILocationScale(location::AbstractVector{<:Real}, + scale::Union{<:AbstractTriangular{<:Real}, <:Diagonal{<:Real}}, - q_base::ContinuousUnivariateDistribution) + dist::ContinuousUnivariateDistribution) # Restricting all the arguments to have the same types creates problems # with dual-variable-based AD frameworks. 
- @assert (length(μ) == size(L,1)) && (length(μ) == size(L,2)) - new{typeof(μ), typeof(L), typeof(q_base)}(μ, L, q_base) + @assert (length(location) == size(scale,1)) && (length(location) == size(scale,2)) + new{typeof(location), typeof(scale), typeof(dist)}(location, scale, dist) end end @@ -87,7 +87,7 @@ function VIFullRankGaussian(μ::AbstractVector{T}, L::AbstractTriangular{T}) whe end """ - VIFullRankGaussian(μ::AbstractVector{T}, L::AbstractTriangular{T}) + VIMeanFieldGaussian(μ::AbstractVector{T}, L::Diagonal{T}) This constructs a multivariate Gaussian distribution with a diagonal covariance matrix. """ diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl index e965ea73..e4e93327 100644 --- a/src/objectives/elbo/advi.jl +++ b/src/objectives/elbo/advi.jl @@ -1,9 +1,23 @@ """ - ADVI + ADVI( + prob, + n_samples::Int; + entropy::AbstractEntropyEstimator = ClosedFormEntropy(), + cv::Union{<:AbstractControlVariate, Nothing} = nothing, + b = Bijectors.identity + ) Automatic differentiation variational inference (ADVI; Kucukelbir *et al.* 2017) objective. +# Arguments +- `prob`: An object that implements the order `K == 0` `LogDensityProblems` interface. + - `logdensity` must be differentiable by the selected AD backend. +- `n_samples`: Number of Monte Carlo samples used to estimate the ELBO. +- `entropy`: The estimator for the entropy term. +- `cv`: A control variate +- `b`: A bijector mapping the support of the base distribution to that of `prob`. + # Requirements - ``q_{\\lambda}`` implements `rand`. 
- ``\\pi`` must be differentiable @@ -39,40 +53,87 @@ end Base.show(io::IO, advi::ADVI) = print(io, "ADVI(entropy=$(advi.entropy), cv=$(advi.cv), n_samples=$(advi.n_samples))") -skip_entropy_gradient(advi::ADVI) = skip_entropy_gradient(advi.entropy) - init(advi::ADVI) = init(advi.cv) -function (advi::ADVI)(q_η::ContinuousMultivariateDistribution; - rng ::AbstractRNG = default_rng(), - n_samples ::Int = advi.n_samples, - ηs ::AbstractMatrix = rand(rng, q_η, n_samples), - q_η_entropy::ContinuousMultivariateDistribution = q_η) +function (advi::ADVI)( + rng::AbstractRNG, + q_η::ContinuousMultivariateDistribution, + ηs ::AbstractMatrix +) + n_samples = size(ηs, 2) 𝔼ℓ = mapreduce(+, eachcol(ηs)) do ηᵢ zᵢ, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(advi.b, ηᵢ) (advi.ℓπ(zᵢ) + logdetjacᵢ) / n_samples end - ℍ = advi.entropy(q_η_entropy, ηs) + ℍ = advi.entropy(q_η, ηs) 𝔼ℓ + ℍ end -function estimate_gradient( +""" + (advi::ADVI)( + q_η::ContinuousMultivariateDistribution; + rng::AbstractRNG = Random.default_rng(), + n_samples::Int = advi.n_samples + ) + +Evaluate the ELBO using the ADVI formulation. + +# Arguments +- `q_η`: Variational approximation before applying a bijector (unconstrained support). +- `n_samples`: Number of Monte Carlo samples used to estimate the ELBO. + +""" +function (advi::ADVI)( + q_η::ContinuousMultivariateDistribution; + rng::AbstractRNG = default_rng(), + n_samples::Int = advi.n_samples +) + ηs = rand(rng, q_η, n_samples) + advi(rng, q_η, ηs) +end + +function estimate_advi_gradient_maybe_stl!( rng::AbstractRNG, adbackend::AbstractADType, - advi::ADVI, - est_state, + advi::ADVI{P, B, StickingTheLandingEntropy, CV}, λ::Vector{<:Real}, restructure, - out::DiffResults.MutableDiffResult) - - # Gradient-stopping for computing the sticking-the-landing control variate - q_η_stop = skip_entropy_gradient(advi.entropy) ? 
restructure(λ) : nothing + out::DiffResults.MutableDiffResult +) where {P, B, CV} + q_η_stop = restructure(λ) + grad!(adbackend, λ, out) do λ′ + q_η = restructure(λ′) + ηs = rand(rng, q_η, advi.n_samples) + -advi(rng, q_η_stop, ηs) + end +end +function estimate_advi_gradient_maybe_stl!( + rng::AbstractRNG, + adbackend::AbstractADType, + advi::ADVI{P, B, <:Union{ClosedFormEntropy, FullMonteCarloEntropy}, CV}, + λ::Vector{<:Real}, + restructure, + out::DiffResults.MutableDiffResult +) where {P, B, CV} grad!(adbackend, λ, out) do λ′ q_η = restructure(λ′) - q_η_entropy = skip_entropy_gradient(advi.entropy) ? q_η_stop : q_η - -advi(q_η; rng, q_η_entropy) + ηs = rand(rng, q_η, advi.n_samples) + -advi(rng, q_η, ηs) end +end + +function estimate_gradient( + rng::AbstractRNG, + adbackend::AbstractADType, + advi::ADVI, + est_state, + λ::Vector{<:Real}, + restructure, + out::DiffResults.MutableDiffResult +) + estimate_advi_gradient_maybe_stl!( + rng, adbackend, advi, λ, restructure, out) nelbo = DiffResults.value(out) stat = (elbo=-nelbo,) diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl index 994bdd4f..7f37b619 100644 --- a/src/objectives/elbo/entropy.jl +++ b/src/objectives/elbo/entropy.jl @@ -7,11 +7,9 @@ end skip_entropy_gradient(::ClosedFormEntropy) = false -struct MonteCarloEntropy{IsStickingTheLanding} <: AbstractEntropyEstimator end +abstract type MonteCarloEntropy <: AbstractEntropyEstimator end -MonteCarloEntropy() = MonteCarloEntropy{false}() - -Base.show(io::IO, entropy::MonteCarloEntropy{false}) = print(io, "MonteCarloEntropy()") +struct FullMonteCarloEntropy <: MonteCarloEntropy end """ StickingTheLandingEntropy() @@ -44,11 +42,8 @@ and higher variance when ``\\pi \\not\\approx q_{\\lambda}``. # Reference 1. Roeder, G., Wu, Y., & Duvenaud, D. K. (2017). Sticking the landing: Simple, lower-variance gradient estimators for variational inference. Advances in Neural Information Processing Systems, 30. 
""" -StickingTheLandingEntropy() = MonteCarloEntropy{true}() - -skip_entropy_gradient(::MonteCarloEntropy{IsStickingTheLanding}) where {IsStickingTheLanding} = IsStickingTheLanding -Base.show(io::IO, entropy::MonteCarloEntropy{true}) = print(io, "StickingTheLandingEntropy()") +struct StickingTheLandingEntropy <: MonteCarloEntropy end function (::MonteCarloEntropy)(q, ηs::AbstractMatrix) n_samples = size(ηs, 2) diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl index 40e5dace..2f19ca61 100644 --- a/test/advi_locscale.jl +++ b/test/advi_locscale.jl @@ -19,6 +19,8 @@ struct TestModel{M,L,S} end include("models/normallognormal.jl") +include("models/normal.jl") +include("models/utils.jl") @testset "advi" begin @testset "locscale" begin @@ -27,11 +29,13 @@ include("models/normallognormal.jl") (modelname, modelconstr) ∈ Dict( :NormalLogNormalMeanField => normallognormal_meanfield, :NormalLogNormalFullRank => normallognormal_fullrank, + :NormalMeanField => normal_meanfield, + :NormalFullRank => normal_fullrank, ), (objname, objective) ∈ Dict( :ADVIClosedFormEntropy => (model, b, M) -> ADVI(model, M; b), :ADVIStickingTheLanding => (model, b, M) -> ADVI(model, M; b, entropy = StickingTheLandingEntropy()), - :ADVIFullMonteCarlo => (model, b, M) -> ADVI(model, M; b, entropy = MonteCarloEntropy()), + :ADVIFullMonteCarlo => (model, b, M) -> ADVI(model, M; b, entropy = FullMonteCarloEntropy()), ), (adbackname, adbackend) ∈ Dict( :ForwarDiff => AutoForwardDiff(), diff --git a/test/models/normal.jl b/test/models/normal.jl new file mode 100644 index 00000000..a677af93 --- /dev/null +++ b/test/models/normal.jl @@ -0,0 +1,51 @@ + +struct TestMvNormal{M,S} + μ::M + Σ::S +end + +function LogDensityProblems.logdensity(model::TestMvNormal, θ) + @unpack μ, Σ = model + logpdf(MvNormal(μ, Σ), θ) +end + +function LogDensityProblems.dimension(model::TestMvNormal) + length(model.μ) +end + +function LogDensityProblems.capabilities(::Type{<:TestMvNormal}) + 
LogDensityProblems.LogDensityOrder{0}() +end + +function Bijectors.bijector(model::TestMvNormal) + identity +end + +function normal_fullrank(realtype; rng = default_rng()) + n_dims = 5 + + μ = randn(rng, realtype, n_dims) + L₀ = sample_cholesky(rng, n_dims) + ϵ = eps(realtype)*10 + Σ = (L₀*L₀' + ϵ*I) |> Hermitian + + Σ_chol = cholesky(Σ) + model = TestMvNormal(μ, PDMats.PDMat(Σ, Σ_chol)) + + L = Σ_chol.L |> LowerTriangular + + TestModel(model, μ, L, n_dims, false) +end + +function normal_meanfield(realtype; rng = default_rng()) + n_dims = 5 + + μ = randn(rng, realtype, n_dims) + σ = log.(exp.(randn(rng, realtype, n_dims)) .+ 1) + + model = TestMvNormal(μ, PDMats.PDiagMat(σ)) + + L = σ |> Diagonal + + TestModel(model, μ, L, n_dims, true) +end diff --git a/test/models/normallognormal.jl b/test/models/normallognormal.jl index 18e8b4a3..ca8c9a4d 100644 --- a/test/models/normallognormal.jl +++ b/test/models/normallognormal.jl @@ -32,8 +32,8 @@ function normallognormal_fullrank(realtype; rng = default_rng()) μ_x = randn(rng, realtype) σ_x = ℯ μ_y = randn(rng, realtype, n_dims) - L₀_y = randn(rng, realtype, n_dims, n_dims) |> LowerTriangular - ϵ = realtype(n_dims*2) + L₀_y = sample_cholesky(rng, n_dims) + ϵ = eps(realtype)*10 Σ_y = (L₀_y*L₀_y' + ϵ*I) |> Hermitian model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDMat(Σ_y)) diff --git a/test/models/utils.jl b/test/models/utils.jl new file mode 100644 index 00000000..c1a9a407 --- /dev/null +++ b/test/models/utils.jl @@ -0,0 +1,8 @@ + +function sample_cholesky(rng::AbstractRNG, n_dims::Int) + A = randn(rng, n_dims, n_dims) + L = tril(A) + idx = diagind(L) + @. 
L[idx] = log(exp(L[idx]) + 1) + L |> LowerTriangular +end From b72c2585a1d3e461d9903884d16d9b019c11e828 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sun, 16 Jul 2023 03:49:08 +0100 Subject: [PATCH 061/144] fix type bugs, relax test threshold for the exact inference tests --- test/advi_locscale.jl | 8 ++++---- test/models/normal.jl | 5 ++--- test/models/normallognormal.jl | 5 ++--- test/models/utils.jl | 4 ++-- 4 files changed, 10 insertions(+), 12 deletions(-) diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl index 2f19ca61..1552be5e 100644 --- a/test/advi_locscale.jl +++ b/test/advi_locscale.jl @@ -72,7 +72,7 @@ include("models/utils.jl") Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true) q, stats = optimize( obj, q₀, T; - optimizer = Optimisers.Adam(1e-3), + optimizer = Optimisers.Adam(1e-2), progress = PROGRESS, rng = rng, adbackend = adbackend, @@ -82,7 +82,7 @@ include("models/utils.jl") L = q.scale Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true) - @test Δλ ≤ Δλ₀/√T + @test Δλ ≤ Δλ₀/T^(1/4) @test eltype(μ) == eltype(μ_true) @test eltype(L) == eltype(L_true) end @@ -91,7 +91,7 @@ include("models/utils.jl") rng = Philox4x(UInt64, seed, 8) q, stats = optimize( obj, q₀, T; - optimizer = Optimisers.Adam(1e-3), + optimizer = Optimisers.Adam(realtype(1e-2)), progress = PROGRESS, rng = rng, adbackend = adbackend, @@ -102,7 +102,7 @@ include("models/utils.jl") rng_repl = Philox4x(UInt64, seed, 8) q, stats = optimize( obj, q₀, T; - optimizer = Optimisers.Adam(1e-3), + optimizer = Optimisers.Adam(realtype(1e-2)), progress = PROGRESS, rng = rng_repl, adbackend = adbackend, diff --git a/test/models/normal.jl b/test/models/normal.jl index a677af93..f60ad5f3 100644 --- a/test/models/normal.jl +++ b/test/models/normal.jl @@ -25,9 +25,8 @@ function normal_fullrank(realtype; rng = default_rng()) n_dims = 5 μ = randn(rng, realtype, n_dims) - L₀ = sample_cholesky(rng, n_dims) - ϵ = eps(realtype)*10 - Σ = (L₀*L₀' + ϵ*I) |> Hermitian + L₀ = sample_cholesky(rng, 
realtype, n_dims) + Σ = L₀*L₀' |> Hermitian Σ_chol = cholesky(Σ) model = TestMvNormal(μ, PDMats.PDMat(Σ, Σ_chol)) diff --git a/test/models/normallognormal.jl b/test/models/normallognormal.jl index ca8c9a4d..cab73cce 100644 --- a/test/models/normallognormal.jl +++ b/test/models/normallognormal.jl @@ -32,9 +32,8 @@ function normallognormal_fullrank(realtype; rng = default_rng()) μ_x = randn(rng, realtype) σ_x = ℯ μ_y = randn(rng, realtype, n_dims) - L₀_y = sample_cholesky(rng, n_dims) - ϵ = eps(realtype)*10 - Σ_y = (L₀_y*L₀_y' + ϵ*I) |> Hermitian + L₀_y = sample_cholesky(rng, realtype, n_dims) + Σ_y = L₀_y*L₀_y' |> Hermitian model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDMat(Σ_y)) diff --git a/test/models/utils.jl b/test/models/utils.jl index c1a9a407..3d483c46 100644 --- a/test/models/utils.jl +++ b/test/models/utils.jl @@ -1,6 +1,6 @@ -function sample_cholesky(rng::AbstractRNG, n_dims::Int) - A = randn(rng, n_dims, n_dims) +function sample_cholesky(rng::AbstractRNG, type::Type, n_dims::Int) + A = randn(rng, type, n_dims, n_dims) L = tril(A) idx = diagind(L) @. 
L[idx] = log(exp(L[idx]) + 1) From a8df9eb8b635e9805e3f307b7b5b64ccb4f1f970 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sun, 13 Aug 2023 01:33:15 +0100 Subject: [PATCH 062/144] refactor derivative utils to match NormalizingFlows.jl with extras --- Project.toml | 20 +++++++-- ext/AdvancedVIEnzymeExt.jl | 26 +++++++++++ ext/AdvancedVIForwardDiffExt.jl | 29 ++++++++++++ ext/AdvancedVIReverseDiffExt.jl | 23 ++++++++++ ext/AdvancedVIZygoteExt.jl | 24 ++++++++++ src/AdvancedVI.jl | 79 ++++++++++++++++++--------------- src/compat/enzyme.jl | 16 ------- src/compat/reversediff.jl | 19 -------- src/compat/zygote.jl | 13 ------ src/grad.jl | 30 ------------- src/objectives/elbo/advi.jl | 6 ++- test/ad.jl | 7 ++- 12 files changed, 167 insertions(+), 125 deletions(-) create mode 100644 ext/AdvancedVIEnzymeExt.jl create mode 100644 ext/AdvancedVIForwardDiffExt.jl create mode 100644 ext/AdvancedVIReverseDiffExt.jl create mode 100644 ext/AdvancedVIZygoteExt.jl delete mode 100644 src/compat/enzyme.jl delete mode 100644 src/compat/reversediff.jl delete mode 100644 src/compat/zygote.jl delete mode 100644 src/grad.jl diff --git a/Project.toml b/Project.toml index cf698f7a..ab00d674 100644 --- a/Project.toml +++ b/Project.toml @@ -6,10 +6,10 @@ version = "0.2.4" ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" @@ -21,24 +21,36 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df" SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" StatsBase = 
"2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" -Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + +[weakdeps] +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] ADTypes = "0.1" Bijectors = "0.11, 0.12, 0.13" +DiffResults = "1.0.3" Distributions = "0.21, 0.22, 0.23, 0.24, 0.25" DocStringExtensions = "0.8, 0.9" -ForwardDiff = "0.10.3" +ForwardDiff = "0.10.25" +LogDensityProblems = "2.1.1" +Optimisers = "0.2.16" ProgressMeter = "1.0.0" Requires = "0.5, 1.0" +ReverseDiff = "1.14" StatsBase = "0.32, 0.33, 0.34" StatsFuns = "0.8, 0.9, 1" -Tracker = "0.2.3" julia = "1.6" [extras] +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] test = ["Pkg", "Test"] diff --git a/ext/AdvancedVIEnzymeExt.jl b/ext/AdvancedVIEnzymeExt.jl new file mode 100644 index 00000000..8333299f --- /dev/null +++ b/ext/AdvancedVIEnzymeExt.jl @@ -0,0 +1,26 @@ + +module AdvancedVIEnzymeExt + +if isdefined(Base, :get_extension) + using Enzyme + using AdvancedVI + using AdvancedVI: ADTypes, DiffResults +else + using ..Enzyme + using ..AdvancedVI + using ..AdvancedVI: ADTypes, DiffResults +end + +# Enzyme doesn't support f::Bijectors (see https://github.com/EnzymeAD/Enzyme.jl/issues/916) +function AdvancedVI.value_and_gradient!( + ad::ADTypes.AutoEnzyme, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult +) where {T<:Real} + y = f(θ) + DiffResults.value!(out, y) + ∇θ = DiffResults.gradient(out) + fill!(∇θ, zero(T)) + Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(θ, ∇θ)) + return out +end + +end diff --git 
a/ext/AdvancedVIForwardDiffExt.jl b/ext/AdvancedVIForwardDiffExt.jl new file mode 100644 index 00000000..e6b03af2 --- /dev/null +++ b/ext/AdvancedVIForwardDiffExt.jl @@ -0,0 +1,29 @@ + +module AdvancedVIForwardDiffExt + +if isdefined(Base, :get_extension) + using ForwardDiff + using AdvancedVI + using AdvancedVI: ADTypes, DiffResults +else + using ..ForwardDiff + using ..AdvancedVI + using ..AdvancedVI: ADTypes, DiffResults +end + +# extract chunk size from AutoForwardDiff +getchunksize(::ADTypes.AutoForwardDiff{chunksize}) where {chunksize} = chunksize +function AdvancedVI.value_and_gradient!( + ad::ADTypes.AutoForwardDiff, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult +) where {T<:Real} + chunk_size = getchunksize(ad) + config = if isnothing(chunk_size) + ForwardDiff.GradientConfig(f, θ) + else + ForwardDiff.GradientConfig(f, θ, ForwardDiff.Chunk(length(θ), chunk_size)) + end + ForwardDiff.gradient!(out, f, θ, config) + return out +end + +end diff --git a/ext/AdvancedVIReverseDiffExt.jl b/ext/AdvancedVIReverseDiffExt.jl new file mode 100644 index 00000000..fd7fbaab --- /dev/null +++ b/ext/AdvancedVIReverseDiffExt.jl @@ -0,0 +1,23 @@ + +module AdvancedVIReverseDiffExt + +if isdefined(Base, :get_extension) + using AdvancedVI + using AdvancedVI: ADTypes, DiffResults + using ReverseDiff +else + using ..AdvancedVI + using ..AdvancedVI: ADTypes, DiffResults + using ..ReverseDiff +end + +# ReverseDiff without compiled tape +function AdvancedVI.value_and_gradient!( + ad::ADTypes.AutoReverseDiff, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult +) where {T<:Real} + tp = ReverseDiff.GradientTape(f, θ) + ReverseDiff.gradient!(out, tp, θ) + return out +end + +end diff --git a/ext/AdvancedVIZygoteExt.jl b/ext/AdvancedVIZygoteExt.jl new file mode 100644 index 00000000..b447d071 --- /dev/null +++ b/ext/AdvancedVIZygoteExt.jl @@ -0,0 +1,24 @@ + +module AdvancedVIZygoteExt + +if isdefined(Base, :get_extension) + using AdvancedVI + using AdvancedVI: 
ADTypes, DiffResults + using Zygote +else + using ..AdvancedVI + using ..AdvancedVI: ADTypes, DiffResults + using ..Zygote +end + +function AdvancedVI.value_and_gradient!( + ad::ADTypes.AutoZygote, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult +) where {T<:Real} + y, back = Zygote.pullback(f, θ) + ∇θ = back(one(T)) + DiffResults.value!(out, y) + DiffResults.gradient!(out, first(∇θ)) + return out +end + +end diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 9f93885c..697f3c83 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -21,9 +21,9 @@ using LinearAlgebra: AbstractTriangular using LogDensityProblems -using ADTypes +using ADTypes, DiffResults using ADTypes: AbstractADType -using ForwardDiff, Tracker + using FillArrays using PDMats @@ -34,29 +34,23 @@ using StatsBase: entropy const DEBUG = Bool(parse(Int, get(ENV, "DEBUG_ADVANCEDVI", "0"))) -using Requires -function __init__() - @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin - include("compat/zygote.jl") - end - @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin - include("compat/reversediff.jl") - end - @require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" begin - include("compat/enzyme.jl") - end -end - +# derivatives """ - grad!(f, λ, out) - -Computes the gradients of the objective f. Default implementation is provided for -`VariationalInference{AD}` where `AD` is either `ForwardDiffAD` or `TrackerAD`. -This implicitly also gives a default implementation of `optimize!`. + value_and_gradient!( + ad::ADTypes.AbstractADType, + f, + θ::AbstractVector{T}, + out::DiffResults.MutableDiffResult + ) where {T<:Real} + +Compute the value and gradient of a function `f` at `θ` using the automatic +differentiation backend `ad`. The result is stored in `out`. +The function `f` must return a scalar value. The gradient is stored in `out` as a +vector of the same length as `θ`. """ -function grad! end +function value_and_gradient! 
end -include("grad.jl") +export value_and_gradient! # estimators abstract type AbstractVariationalObjective end @@ -94,21 +88,8 @@ export VIFullRankGaussian, VIMeanFieldGaussian -""" - optimize(model, alg::VariationalInference) - optimize(model, alg::VariationalInference, q::VariationalPosterior) - optimize(model, alg::VariationalInference, getq::Function, θ::AbstractArray) - -Constructs the variational posterior from the `model` and performs the optimization -following the configuration of the given `VariationalInference` instance. - -# Arguments -- `model`: `Turing.Model` or `Function` z ↦ log p(x, z) where `x` denotes the observations -- `alg`: the VI algorithm used -- `q`: a `VariationalPosterior` for which it is assumed a specialized implementation of the variational objective used exists. -- `getq`: function taking parameters `θ` as input and returns a `VariationalPosterior` -- `θ`: only required if `getq` is used, in which case it is the initial parameters for the variational posterior -""" +# Optimization Routine + function optimize end include("optimize.jl") @@ -117,4 +98,28 @@ export optimize include("utils.jl") + +# optional dependencies +if !isdefined(Base, :get_extension) # check whether :get_extension is defined in Base + using Requires +end + +using Requires +function __init__() + @static if !isdefined(Base, :get_extension) + @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin + include("../ext/AdvancedVIZygoteExt.jl") + end + @require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" begin + include("../ext/AdvancedVIForwardDiffExt.jl") + end + @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin + include("../ext/AdvancedVIReverseDiffExt.jl") + end + @require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" begin + include("../ext/AdvancedVIEnzymeExt.jl") + end + end +end end # module + diff --git a/src/compat/enzyme.jl b/src/compat/enzyme.jl deleted file mode 100644 index cab50862..00000000 --- a/src/compat/enzyme.jl 
+++ /dev/null @@ -1,16 +0,0 @@ - -function AdvancedVI.grad!( - f::Function, - ::AutoEnzyme, - λ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult, - ) - # Use `Enzyme.ReverseWithPrimal` once it is released: - # https://github.com/EnzymeAD/Enzyme.jl/pull/598 - y = f(λ) - DiffResults.value!(out, y) - dy = DiffResults.gradient(out) - fill!(dy, 0) - Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(λ, dy)) - return out -end diff --git a/src/compat/reversediff.jl b/src/compat/reversediff.jl deleted file mode 100644 index 4d8f87d8..00000000 --- a/src/compat/reversediff.jl +++ /dev/null @@ -1,19 +0,0 @@ -using .ReverseDiff: compile, GradientTape -using .ReverseDiff.DiffResults: GradientResult - -tape(f, x) = GradientTape(f, x) -function taperesult(f, x) - return tape(f, x), GradientResult(x) -end - -# Precompiled tapes are not properly supported yet. -function AdvancedVI.grad!( - f::Function, - ::AutoReverseDiff, - λ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult, - ) - tp = tape(f, λ) - ReverseDiff.gradient!(out, tp, λ) - return out -end diff --git a/src/compat/zygote.jl b/src/compat/zygote.jl deleted file mode 100644 index f1a29b87..00000000 --- a/src/compat/zygote.jl +++ /dev/null @@ -1,13 +0,0 @@ - -function AdvancedVI.grad!( - f::Function, - ::AutoZygote, - λ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult, - ) - y, back = Zygote.pullback(f, λ) - dy = first(back(1.0)) - DiffResults.value!(out, y) - DiffResults.gradient!(out, dy) - return out -end diff --git a/src/grad.jl b/src/grad.jl deleted file mode 100644 index e68e1623..00000000 --- a/src/grad.jl +++ /dev/null @@ -1,30 +0,0 @@ - -# default implementations -function grad!( - f::Function, - adtype::AutoForwardDiff{chunksize}, - λ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult -) where {chunksize} - # Set chunk size and do ForwardMode. 
- config = if isnothing(chunksize) - ForwardDiff.GradientConfig(f, λ) - else - ForwardDiff.GradientConfig(f, λ, ForwardDiff.Chunk(length(λ), chunksize)) - end - ForwardDiff.gradient!(out, f, λ, config) -end - -function grad!( - f::Function, - ::AutoTracker, - λ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult -) - λ_tracked = Tracker.param(λ) - y = f(λ_tracked) - Tracker.back!(y, 1.0) - - DiffResults.value!(out, Tracker.data(y)) - DiffResults.gradient!(out, Tracker.grad(λ_tracked)) -end diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl index e4e93327..d308db0a 100644 --- a/src/objectives/elbo/advi.jl +++ b/src/objectives/elbo/advi.jl @@ -101,11 +101,12 @@ function estimate_advi_gradient_maybe_stl!( out::DiffResults.MutableDiffResult ) where {P, B, CV} q_η_stop = restructure(λ) - grad!(adbackend, λ, out) do λ′ + f(λ′) = begin q_η = restructure(λ′) ηs = rand(rng, q_η, advi.n_samples) -advi(rng, q_η_stop, ηs) end + value_and_gradient!(adbackend, f, λ, out) end function estimate_advi_gradient_maybe_stl!( @@ -116,11 +117,12 @@ function estimate_advi_gradient_maybe_stl!( restructure, out::DiffResults.MutableDiffResult ) where {P, B, CV} - grad!(adbackend, λ, out) do λ′ + f(λ′) = begin q_η = restructure(λ′) ηs = rand(rng, q_η, advi.n_samples) -advi(rng, q_η, ηs) end + value_and_gradient!(adbackend, f, λ, out) end function estimate_gradient( diff --git a/test/ad.jl b/test/ad.jl index 1efa536b..9df26d9f 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -9,15 +9,14 @@ using ADTypes :ReverseDiff => AutoReverseDiff(), :Zygote => AutoZygote(), :Tracker => AutoTracker(), - :Enzyme => AutoEnzyme(), + # :Enzyme => AutoEnzyme(), # Currently not tested against. 
) D = 10 A = randn(D, D) λ = randn(D) grad_buf = DiffResults.GradientResult(λ) - AdvancedVI.grad!(adsymbol, λ, grad_buf) do λ′ - λ′'*A*λ′ / 2 - end + f(λ′) = λ′'*A*λ′ / 2 + AdvancedVI.value_and_gradient!(adsymbol, f, λ, grad_buf) ∇ = DiffResults.gradient(grad_buf) f = DiffResults.value(grad_buf) @test ∇ ≈ (A + A')*λ/2 From e8db6a7ac62d1916969aaaeea336677ba19eafa0 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sun, 13 Aug 2023 01:34:14 +0100 Subject: [PATCH 063/144] add documentation, refactor optimize --- docs/Project.toml | 2 +- docs/make.jl | 11 ++-- docs/src/advi.md | 36 ++++++++++- docs/src/families.md | 32 ++++++---- src/objectives/elbo/entropy.jl | 32 ---------- src/optimize.jl | 106 +++++++++++++++++++++------------ 6 files changed, 130 insertions(+), 89 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index fc885857..c625d07f 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -4,4 +4,4 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" [compat] -Documenter = "0.26" \ No newline at end of file +Documenter = "0.26, 0.27" \ No newline at end of file diff --git a/docs/make.jl b/docs/make.jl index d2a01d1b..b9a8eb5f 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,4 +1,5 @@ -#using AdvancedVI + +using AdvancedVI using Documenter DocMeta.setdocmeta!( @@ -9,9 +10,9 @@ makedocs(; sitename = "AdvancedVI.jl", modules = [AdvancedVI], format = Documenter.HTML(; prettyurls = get(ENV, "CI", nothing) == "true"), - pages = ["index.md", - "families.md", - "advi.md"], + pages = ["Home" => "index.md", + "Families" => "families.md", + "ADVI" => "advi.md"], ) -deploydocs(; repo="github.com/TuringLang/AdvancedVI.jl", devbranch="main") +deploydocs(; repo="github.com/TuringLang/AdvancedVI.jl", push_preview=true) diff --git a/docs/src/advi.md b/docs/src/advi.md index 4f4a2eca..0597e03c 100644 --- a/docs/src/advi.md +++ b/docs/src/advi.md @@ -1,5 +1,8 @@ # [Automatic Differentiation Variational 
Inference](@id advi) + +# Introduction + The automatic differentiation variational inference (ADVI; Kucukelbir *et al.* 2017) objective is a method for estimating the evidence lower bound between a target posterior distribution ``\pi`` and a variational approximation ``q_{\phi,\lambda}``. By maximizing ADVI objective, it is equivalent to solving the problem @@ -17,8 +20,8 @@ bijectors will automatically match the potentially constrained support of the ta In particular, ADVI implicitly forms a variational approximation ``q_{\phi,\lambda}`` from a reparameterizable distribution ``q_{\lambda}`` and a bijector ``\phi`` such that ```math -z &\sim q_{\phi,\lambda} \qquad\Leftrightarrow\qquad -z &\stackrel{d}{=} \phi^{-1}\left(\eta\right);\quad \eta \sim q_{\lambda} +z \sim q_{\phi,\lambda} \qquad\Leftrightarrow\qquad +z \stackrel{d}{=} \phi^{-1}\left(\eta\right);\quad \eta \sim q_{\lambda} ``` ADVI provides a principled way to compute the evidence lower bound for ``q_{\phi,\lambda}``. @@ -53,15 +56,44 @@ coined by Titsias and Lázaro-Gredilla (2014). Bijectors were generalized by Dillon *et al.* (2017) and later implemented in Julia by Fjelde *et al.* (2017). +# The `ADVI` Objective ```@docs ADVI ``` +# The "Sticking the Landing" Control Variate +The STL control variate was proposed by Roeder *et al.* (2017). +By slightly modifying the differentiation path, it implicitly forms a control variate of the form of +```math +\mathrm{CV}_{\mathrm{STL}}\left(z\right) \triangleq \mathbb{H}\left(q_{\lambda}\right) + \log q_{\lambda}\left(z\right), +``` +which has a mean of zero. 
+ +Adding this to the closed-form entropy ELBO estimator yields the STL estimator: +```math +\begin{aligned} + \widehat{\mathrm{ELBO}}_{\mathrm{STL}}\left(\lambda\right) + &\triangleq \mathbb{E}\left[ \log \pi \left(z\right) \right] - \log q_{\lambda} \left(z\right) \\ + &= \mathbb{E}\left[ \log \pi\left(z\right) \right] + + \mathbb{H}\left(q_{\lambda}\right) - \mathrm{CV}_{\mathrm{STL}}\left(z\right) \\ + &= \widehat{\mathrm{ELBO}}\left(\lambda\right) + - \mathrm{CV}_{\mathrm{STL}}\left(z\right), +\end{aligned} +``` +which has the same expectation, but lower variance when ``\pi \approx q_{\lambda}``, and higher variance when ``\pi \not\approx q_{\lambda}``. +The conditions for which the STL estimator results in lower variance are still an active subject of research. + +The STL control variate can be used by changing the entropy estimator as follows: +```julia +ADVI(prob, n_samples; entropy = StickingTheLandingEntropy(), b = bijector) +``` + # References 1. Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. Journal of machine learning research. 2. Titsias, M., & Lázaro-Gredilla, M. (2014, June). Doubly stochastic variational Bayes for non-conjugate inference. In International conference on machine learning (pp. 1971-1979). PMLR. 3. Dillon, J. V., Langmore, I., Tran, D., Brevdo, E., Vasudevan, S., Moore, D., ... & Saurous, R. A. (2017). Tensorflow distributions. arXiv preprint arXiv:1711.10604. 4. Fjelde, T. E., Xu, K., Tarek, M., Yalburgi, S., & Ge, H. (2020, February). Bijectors. jl: Flexible transformations for probability distributions. In Symposium on Advances in Approximate Bayesian Inference (pp. 1-17). PMLR. +5. Roeder, G., Wu, Y., & Duvenaud, D. K. (2017). Sticking the landing: Simple, lower-variance gradient estimators for variational inference. Advances in Neural Information Processing Systems, 30. 
diff --git a/docs/src/families.md b/docs/src/families.md index f203cf18..d326ce7a 100644 --- a/docs/src/families.md +++ b/docs/src/families.md @@ -1,5 +1,5 @@ -# [Variational Families](@id families) +# Variational Families ## Location-Scale Variational Family @@ -25,34 +25,42 @@ VIMeanFieldGaussian ### Examples -A full-rank variational family can be formed by choosing ```@repl locscale -using AdvancedVI, LinearAlgebra +using AdvancedVI, LinearAlgebra, Distributions; μ = zeros(2); -L = diagm(ones(2)) |> LowerTriangular; -``` - -A mean-field variational family can be formed by choosing -```@repl locscale -μ = zeros(2); -L = ones(2) |> Diagonal; ``` Gaussian variational family: ```@repl locscale +L = diagm(ones(2)) |> LowerTriangular; q = VIFullRankGaussian(μ, L) + +L = ones(2) |> Diagonal; q = VIMeanFieldGaussian(μ, L) ``` Sudent-T Variational Family: ```@repl locscale -ν = 3 -q = VILocationScale(μ, L, StudentT(ν)) +ν = 3; + +# Full-Rank +L = diagm(ones(2)) |> LowerTriangular; +q = VILocationScale(μ, L, TDist(ν)) + +# Mean-Field +L = ones(2) |> Diagonal; +q = VILocationScale(μ, L, TDist(ν)) ``` Multivariate Laplace family: ```@repl locscale +# Full-Rank +L = diagm(ones(2)) |> LowerTriangular; +q = VILocationScale(μ, L, Laplace()) + +# Mean-Field +L = ones(2) |> Diagonal; q = VILocationScale(μ, L, Laplace()) ``` diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl index 7f37b619..e9f180f5 100644 --- a/src/objectives/elbo/entropy.jl +++ b/src/objectives/elbo/entropy.jl @@ -11,38 +11,6 @@ abstract type MonteCarloEntropy <: AbstractEntropyEstimator end struct FullMonteCarloEntropy <: MonteCarloEntropy end -""" - StickingTheLandingEntropy() - -# Explanation - -The STL estimator forms a control variate of the form of - -```math -\\mathrm{CV}_{\\mathrm{STL}}\\left(z\\right) = - \\mathbb{E}\\left[ -\\log q\\left(z\\right) \\right] - + \\log q\\left(z\\right) = \\mathbb{H}\\left(q_{\\lambda}\\right) + \\log q_{\\lambda}\\left(z\\right), -``` -where, 
for the score term, the gradient is stopped from propagating. - -Adding this to the closed-form entropy ELBO estimator yields the STL estimator: -```math -\\begin{aligned} - \\widehat{\\mathrm{ELBO}}_{\\mathrm{STL}}\\left(\\lambda\\right) - &\\triangleq \\mathbb{E}\\left[ \\log \\pi \\left(z\\right) \\right] - \\log q_{\\lambda} \\left(z\\right) \\\\ - &= \\mathbb{E}\\left[ \\log \\pi\\left(z\\right) \\right] - + \\mathbb{H}\\left(q_{\\lambda}\\right) - \\mathrm{CV}_{\\mathrm{STL}}\\left(z\\right) \\\\ - &= \\widehat{\\mathrm{ELBO}}\\left(\\lambda\\right) - - \\mathrm{CV}_{\\mathrm{STL}}\\left(z\\right), -\\end{aligned} -``` -which has the same expectation, but lower variance when ``\\pi \\approx q_{\\lambda}``, -and higher variance when ``\\pi \\not\\approx q_{\\lambda}``. - -# Reference -1. Roeder, G., Wu, Y., & Duvenaud, D. K. (2017). Sticking the landing: Simple, lower-variance gradient estimators for variational inference. Advances in Neural Information Processing Systems, 30. -""" - struct StickingTheLandingEntropy <: MonteCarloEntropy end function (::MonteCarloEntropy)(q, ηs::AbstractMatrix) diff --git a/src/optimize.jl b/src/optimize.jl index 8b36df04..ef16dcce 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -4,73 +4,105 @@ function pm_next!(pm, stats::NamedTuple) end """ - optimize!(vo, alg::VariationalInference{AD}, q::VariationalPosterior, model::Model, θ; optimizer = TruncatedADAGrad()) + optimize( + objective ::AbstractVariationalObjective, + restructure, + λ₀ ::AbstractVector{<:Real}, + n_max_iter ::Int; + kwargs... + ) -Iteratively updates parameters by calling `grad!` and using the given `optimizer` to compute -the steps. +Optimize the variational objective `objective` by estimating (stochastic) gradients, where the variational approximation can be constructed by passing the variational parameters `λ₀` to the function `restructure`. + + optimize( + objective ::AbstractVariationalObjective, + q, + n_max_iter::Int; + kwargs... 
+ ) + +Optimize the variational objective `objective` by estimating (stochastic) gradients, where the initial variational approximation `q₀` supports the `Optimisers.destructure` interface. + +# Arguments +- `objective`: Variational Objective. +- `λ₀`: Initial value of the variational parameters. +- `restructure`: Function that reconstructs the variational approximation from the flattened parameters. +- `q`: Initial variational approximation. The variational parameters must be extractable through `Optimisers.destructure`. +- `n_max_iter`: Maximum number of iterations. + +# Keyword Arguments +- `optimizer`: Optimizer used for inference. (Type: `<: Optimisers.AbstractRule`; Default: `Adam`.) +- `rng`: Random number generator. (Type: `<: AbstractRNG`; Default: `Random.default_rng()`.) +- `show_progress`: Whether to show the progress bar. (Type: `<: Bool`; Default: `true`.) +- `callback!`: Callback function called after every iteration. The signature is `cb(; est_state, stat, restructure, λ)`, which returns a dictionary-like object containing statistics to be displayed on the progress bar. The variational approximation can be reconstructed as `restructure(λ)`. If `objective` is stateful, `est_state` contains its state. (Default: `nothing`.) +- `prog`: Progress bar configuration. (Default: `ProgressMeter.Progress(n_max_iter; desc="Optimizing", barlen=31, showspeed=true, enabled=show_progress)`.) + +# Returns +- `λ`: Variational parameters optimizing the variational objective. +- `stats`: Statistics gathered during inference. +- `opt_state`: Final state of the optimiser. """ function optimize( - objective ::AbstractVariationalObjective, + objective ::AbstractVariationalObjective, restructure, - λ ::AbstractVector{<:Real}, - n_max_iter::Int; - optimizer ::Optimisers.AbstractRule = Optimisers.Adam(), - rng ::AbstractRNG = default_rng(), - progress ::Bool = true, - callback! = nothing, - terminate = (args...) 
-> false, - adbackend::AbstractADType = AutoForwardDiff(), + λ₀ ::AbstractVector{<:Real}, + n_max_iter ::Int; + optimizer ::Optimisers.AbstractRule = Optimisers.Adam(), + rng ::AbstractRNG = default_rng(), + show_progress::Bool = true, + callback! = nothing, + #convergence = (args...) -> (false, con_state), + adbackend::AbstractADType = AutoForwardDiff(), + prog = ProgressMeter.Progress( + n_max_iter; + desc = "Optimizing", + barlen = 31, + showspeed = true, + enabled = show_progress + ) ) - opt_state = Optimisers.init(optimizer, λ) + λ = copy(λ₀) + opt_state = Optimisers.setup(optimizer, λ) est_state = init(objective) + #con_state = init(convergence) grad_buf = DiffResults.GradientResult(λ) - - prog = ProgressMeter.Progress(n_max_iter; - barlen = 0, - enabled = progress, - showspeed = true) - stats = Vector{NamedTuple}(undef, n_max_iter) + stats = NamedTuple[] for t = 1:n_max_iter stat = (iteration=t,) grad_buf, est_state, stat′ = estimate_gradient( rng, adbackend, objective, est_state, λ, restructure, grad_buf) - g = DiffResults.gradient(grad_buf) stat = merge(stat, stat′) - opt_state, Δλ = Optimisers.apply!(optimizer, opt_state, λ, g) - Optimisers.subtract!(λ, Δλ) - - stat′ = (iteration=t, Δλ=norm(Δλ), gradient_norm=norm(g)) + g = DiffResults.gradient(grad_buf) + opt_state, λ = Optimisers.update!(opt_state, λ, g) + stat′ = (iteration=t, gradient_norm=norm(g)) stat = merge(stat, stat′) - q = restructure(λ) - if !isnothing(callback!) - stat′ = callback!(q, stat) + stat′ = callback!(; est_state, stat, restructure, λ) stat = !isnothing(stat′) ? merge(stat′, stat) : stat end AdvancedVI.DEBUG && @debug "Step $t" stat... 
pm_next!(prog, stat) - stats[t] = stat + push!(stats, stat) - # Termination decision is work in progress - if terminate(rng, λ, q, objective, stat) - stats = stats[1:t] - break - end + #convergence(rng, t, restructure, λ, q, objective, stat) + #if terminate() + # break + #end end - λ, stats + λ, map(identity, stats), opt_state end -function optimize(objective::AbstractVariationalObjective, - q, +function optimize(objective ::AbstractVariationalObjective, + q₀, n_max_iter::Int; kwargs...) - λ, restructure = Optimisers.destructure(q) + λ, restructure = Optimisers.destructure(q₀) λ, stats = optimize(objective, restructure, λ, n_max_iter; kwargs...) restructure(λ), stats end From 65a2b37d354798dd40161dc897f7124b5b68b857 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sun, 13 Aug 2023 01:49:57 +0100 Subject: [PATCH 064/144] fix bug missing extension --- Project.toml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/Project.toml b/Project.toml index ab00d674..ffc41a4b 100644 --- a/Project.toml +++ b/Project.toml @@ -28,6 +28,12 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +[extensions] +AdvancedVIEnzymeExt = "Enzyme" +AdvancedVIForwardDiffExt = "ForwardDiff" +AdvancedVIReverseDiffExt = "ReverseDiff" +AdvancedVIZygoteExt = "Zygote" + [compat] ADTypes = "0.1" Bijectors = "0.11, 0.12, 0.13" From 1a02051f6fb8e2c59b39e7faa58c91db7ca589b3 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sun, 13 Aug 2023 01:50:24 +0100 Subject: [PATCH 065/144] remove tracker from tests --- test/ad.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/ad.jl b/test/ad.jl index 9df26d9f..2c4f802a 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -1,6 +1,6 @@ using ReTest -using ForwardDiff, ReverseDiff, Tracker, Enzyme, Zygote +using ForwardDiff, ReverseDiff, Enzyme, Zygote using ADTypes @testset "ad" begin @@ -8,7 +8,6 @@ using ADTypes :ForwardDiff => AutoForwardDiff(), 
:ReverseDiff => AutoReverseDiff(), :Zygote => AutoZygote(), - :Tracker => AutoTracker(), # :Enzyme => AutoEnzyme(), # Currently not tested against. ) D = 10 From d8b5ea5a153e5a484972c8c46e98a58e0b958b95 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sun, 13 Aug 2023 01:50:50 +0100 Subject: [PATCH 066/144] remove export for internal derivative utils --- src/AdvancedVI.jl | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 697f3c83..a1cf360a 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -50,8 +50,6 @@ vector of the same length as `θ`. """ function value_and_gradient! end -export value_and_gradient! - # estimators abstract type AbstractVariationalObjective end @@ -104,11 +102,10 @@ if !isdefined(Base, :get_extension) # check whether :get_extension is defined in using Requires end -using Requires function __init__() @static if !isdefined(Base, :get_extension) - @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin - include("../ext/AdvancedVIZygoteExt.jl") + @require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" begin + include("../ext/AdvancedVIEnzymeExt.jl") end @require ForwardDiff = "e88e6eb3-aa80-5325-afca-941959d7151f" begin include("../ext/AdvancedVIForwardDiffExt.jl") @@ -116,10 +113,11 @@ function __init__() @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin include("../ext/AdvancedVIReverseDiffExt.jl") end - @require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" begin - include("../ext/AdvancedVIEnzymeExt.jl") + @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin + include("../ext/AdvancedVIZygoteExt.jl") end end end -end # module + +end From 818bc2c33fb7513681c06bb6a99cf341c97957dc Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sun, 13 Aug 2023 02:28:47 +0100 Subject: [PATCH 067/144] fix test errors, old interface --- src/optimize.jl | 4 ++-- test/advi_locscale.jl | 36 ++++++++++++++++++------------------ test/runtests.jl | 4 +++- 3 files 
changed, 23 insertions(+), 21 deletions(-) diff --git a/src/optimize.jl b/src/optimize.jl index ef16dcce..7c876b39 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -103,6 +103,6 @@ function optimize(objective ::AbstractVariationalObjective, n_max_iter::Int; kwargs...) λ, restructure = Optimisers.destructure(q₀) - λ, stats = optimize(objective, restructure, λ, n_max_iter; kwargs...) - restructure(λ), stats + λ, stats, opt_state = optimize(objective, restructure, λ, n_max_iter; kwargs...) + restructure(λ), stats, opt_state end diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl index 1552be5e..d4ef7aec 100644 --- a/test/advi_locscale.jl +++ b/test/advi_locscale.jl @@ -69,13 +69,13 @@ include("models/utils.jl") obj = objective(model, b⁻¹, 10) @testset "convergence" begin - Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true) - q, stats = optimize( + Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true) + q, stats, _ = optimize( obj, q₀, T; - optimizer = Optimisers.Adam(1e-2), - progress = PROGRESS, - rng = rng, - adbackend = adbackend, + optimizer = Optimisers.Adam(1e-2), + show_progress = PROGRESS, + rng = rng, + adbackend = adbackend, ) μ = q.location @@ -88,24 +88,24 @@ include("models/utils.jl") end @testset "determinism" begin - rng = Philox4x(UInt64, seed, 8) - q, stats = optimize( + rng = Philox4x(UInt64, seed, 8) + q, stats, _ = optimize( obj, q₀, T; - optimizer = Optimisers.Adam(realtype(1e-2)), - progress = PROGRESS, - rng = rng, - adbackend = adbackend, + optimizer = Optimisers.Adam(realtype(1e-2)), + show_progress = PROGRESS, + rng = rng, + adbackend = adbackend, ) μ = q.location L = q.scale - rng_repl = Philox4x(UInt64, seed, 8) - q, stats = optimize( + rng_repl = Philox4x(UInt64, seed, 8) + q, stats, _ = optimize( obj, q₀, T; - optimizer = Optimisers.Adam(realtype(1e-2)), - progress = PROGRESS, - rng = rng_repl, - adbackend = adbackend, + optimizer = Optimisers.Adam(realtype(1e-2)), + show_progress = PROGRESS, + rng = rng_repl, + adbackend = 
adbackend, ) μ_repl = q.location L_repl = q.scale diff --git a/test/runtests.jl b/test/runtests.jl index ddc1d09c..68225fd9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,8 @@ -using Comonicon +using ReTest using ReTest: @testset, @test + +using Comonicon using Random using Random123 using Statistics From 215abf34639e76b59d3d8b7ad1b64d24ec7500e0 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sun, 13 Aug 2023 02:29:06 +0100 Subject: [PATCH 068/144] fix wrong derivative interface, add documentation --- src/objectives/elbo/advi.jl | 23 +++++++++-------------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl index d308db0a..8bc14bc9 100644 --- a/src/objectives/elbo/advi.jl +++ b/src/objectives/elbo/advi.jl @@ -1,26 +1,21 @@ """ - ADVI( - prob, - n_samples::Int; - entropy::AbstractEntropyEstimator = ClosedFormEntropy(), - cv::Union{<:AbstractControlVariate, Nothing} = nothing, - b = Bijectors.identity - ) + ADVI(prob, n_samples; kwargs...) Automatic differentiation variational inference (ADVI; Kucukelbir *et al.* 2017) objective. # Arguments - `prob`: An object that implements the order `K == 0` `LogDensityProblems` interface. - - `logdensity` must be differentiable by the selected AD backend. -- `n_samples`: Number of Monte Carlo samples used to estimate the ELBO. -- `entropy`: The estimator for the entropy term. -- `cv`: A control variate -- `b`: A bijector mapping the support of the base distribution to that of `prob`. +- `n_samples`: Number of Monte Carlo samples used to estimate the ELBO. (Type `<: Int`.) + +# Keyword Arguments +- `entropy`: The estimator for the entropy term. (Type `<: AbstractEntropyEstimator`; Default: ClosedFormEntropy()) +- `cv`: A control variate. +- `b`: A bijector mapping the support of the base distribution to that of `prob`. (Default: `Bijectors.identity`.) # Requirements - ``q_{\\lambda}`` implements `rand`. 
-- ``\\pi`` must be differentiable +- `logdensity(prob)` must be differentiable by the selected AD backend. Depending on the options, additional requirements on ``q_{\\lambda}`` may apply. """ @@ -106,7 +101,7 @@ function estimate_advi_gradient_maybe_stl!( ηs = rand(rng, q_η, advi.n_samples) -advi(rng, q_η_stop, ηs) end - grad!(adbackend, f, λ, out) + value_and_gradient!(adbackend, f, λ, out) end function estimate_advi_gradient_maybe_stl!( From 88ad7680a928932be97e1f075d5cd1c0d497a651 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sun, 13 Aug 2023 02:29:25 +0100 Subject: [PATCH 069/144] update documentation --- docs/src/advi.md | 17 ++++++++----- docs/src/families.md | 44 +++++++++++++++++++++------------- src/objectives/elbo/entropy.jl | 9 +++++++ 3 files changed, 48 insertions(+), 22 deletions(-) diff --git a/docs/src/advi.md b/docs/src/advi.md index 0597e03c..37b3541b 100644 --- a/docs/src/advi.md +++ b/docs/src/advi.md @@ -1,7 +1,7 @@ # [Automatic Differentiation Variational Inference](@id advi) -# Introduction +## Introduction The automatic differentiation variational inference (ADVI; Kucukelbir *et al.* 2017) objective is a method for estimating the evidence lower bound between a target posterior distribution ``\pi`` and a variational approximation ``q_{\phi,\lambda}``. By maximizing ADVI objective, it is equivalent to solving the problem @@ -56,13 +56,13 @@ coined by Titsias and Lázaro-Gredilla (2014). Bijectors were generalized by Dillon *et al.* (2017) and later implemented in Julia by Fjelde *et al.* (2017). -# The `ADVI` Objective +## The `ADVI` Objective ```@docs ADVI ``` -# The "Sticking the Landing" Control Variate +## The `StickingTheLanding` Control Variate The STL control variate was proposed by Roeder *et al.* (2017). 
By slightly modifying the differentiation path, it implicitly forms a control variate of the form of ```math @@ -84,12 +84,17 @@ Adding this to the closed-form entropy ELBO estimator yields the STL estimator: which has the same expectation, but lower variance when ``\pi \approx q_{\lambda}``, and higher variance when ``\pi \not\approx q_{\lambda}``. The conditions for which the STL estimator results in lower variance is still an active subject for research. -The STL control variate can be used by changing the entropy estimator as follows: +The STL control variate can be used by changing the entropy estimator using the following object: +```@docs +StickingTheLandingEntropy +``` + +For example: ```julia -ADVI(prob, n_samples; entropy = StickingTheLanding(), b = bijector) +ADVI(prob, n_samples; entropy = StickingTheLandingEntropy(), b = bijector) ``` -# References +## References 1. Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. Journal of machine learning research. 2. Titsias, M., & Lázaro-Gredilla, M. (2014, June). Doubly stochastic variational Bayes for non-conjugate inference. In International conference on machine learning (pp. 1971-1979). PMLR. 3. Dillon, J. V., Langmore, I., Tran, D., Brevdo, E., Vasudevan, S., Moore, D., ... & Saurous, R. A. (2017). Tensorflow distributions. arXiv preprint arXiv:1711.10604. 
diff --git a/docs/src/families.md b/docs/src/families.md index d326ce7a..e6eaa91b 100644 --- a/docs/src/families.md +++ b/docs/src/families.md @@ -1,18 +1,26 @@ -# Variational Families +# Location-Scale Variational Family -## Location-Scale Variational Family - -### Description +## Description The [location-scale](https://en.wikipedia.org/wiki/Location%E2%80%93scale_family) variational family is a family of probability distributions, where their sampling process can be represented as ```math -z = C u + m, +z \sim q_{\lambda} \qquad\Leftrightarrow\qquad +z \stackrel{d}{=} z = C u + m;\quad u \sim \varphi ``` -where ``C`` is the *scale* and ``m`` is the location variational parameter. -This family encompases many - +where ``C`` is the *scale*, ``m`` is the location, and ``\varphi`` is the *base distribution*. +``m`` and ``C`` form the variational parameters ``\lambda = (m, C)`` of ``q_{\lambda}``. +The location-scale family encompasses many practical variational families, which can be instantiated by setting the *base distribution* of ``u`` and the structure of ``C``. +The probability density is given by +```math + q_{\lambda}(z) = {|C|}^{-1} \varphi(C^{-1}(z - m)) +``` +and the entropy is given as +```math + \mathcal{H}(q_{\lambda}) = \mathcal{H}(\varphi) + \log |C|, +``` +where ``\mathcal{H}(\varphi)`` is the entropy of the base distribution. 
-### Constructors +## Constructors ```@docs VILocationScale @@ -23,15 +31,13 @@ VIFullRankGaussian VIMeanFieldGaussian ``` -### Examples +## Gaussian Variational Families -```@repl locscale +Gaussian variational family: +```julia using AdvancedVI, LinearAlgebra, Distributions; μ = zeros(2); -``` -Gaussian variational family: -```@repl locscale L = diagm(ones(2)) |> LowerTriangular; q = VIFullRankGaussian(μ, L) @@ -39,9 +45,12 @@ L = ones(2) |> Diagonal; q = VIMeanFieldGaussian(μ, L) ``` +## Non-Gaussian Variational Families Sudent-T Variational Family: -```@repl locscale +```julia +using AdvancedVI, LinearAlgebra, Distributions; +μ = zeros(2); ν = 3; # Full-Rank @@ -54,7 +63,10 @@ q = VILocationScale(μ, L, TDist(ν)) ``` Multivariate Laplace family: -```@repl locscale +```julia +using AdvancedVI, LinearAlgebra, Distributions; +μ = zeros(2); + # Full-Rank L = diagm(ones(2)) |> LowerTriangular; q = VILocationScale(μ, L, Laplace()) diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl index e9f180f5..0edc47f4 100644 --- a/src/objectives/elbo/entropy.jl +++ b/src/objectives/elbo/entropy.jl @@ -11,6 +11,15 @@ abstract type MonteCarloEntropy <: AbstractEntropyEstimator end struct FullMonteCarloEntropy <: MonteCarloEntropy end +""" + StickingTheLandingEntropy() + +The "sticking the landing" entropy estimator. + +# Requirements +- `q` implements `logpdf`. +- `logpdf(q, η)` must be differentiable by the selected AD framework. 
+""" struct StickingTheLandingEntropy <: MonteCarloEntropy end function (::MonteCarloEntropy)(q, ηs::AbstractMatrix) From e66935bb2881a61cf137ff74899e7117c53a9f46 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sun, 13 Aug 2023 02:34:11 +0100 Subject: [PATCH 070/144] add doc build CI --- .github/workflows/CI.yml | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 9731f20c..158da963 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -61,3 +61,30 @@ jobs: with: github-token: ${{ secrets.GITHUB_TOKEN }} path-to-lcov: lcov.info + docs: + name: Documentation + runs-on: ubuntu-latest + permissions: + contents: write + statuses: write + steps: + - uses: actions/checkout@v3 + - uses: julia-actions/setup-julia@v1 + with: + version: '1' + - name: Configure doc environment + run: | + julia --project=docs/ -e ' + using Pkg + Pkg.develop(PackageSpec(path=pwd())) + Pkg.instantiate()' + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-docdeploy@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - run: | + julia --project=docs -e ' + using Documenter: DocMeta, doctest + using AdvancedVI + DocMeta.setdocmeta!(AdvancedVI, :DocTestSetup, :(using AdvancedVI); recursive=true) + doctest(AdvancedVI)' From 9f1c647a6fb2b945754e808dcb608e3f19c4cae8 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sun, 13 Aug 2023 02:47:56 +0100 Subject: [PATCH 071/144] remove convergence criterion for now --- docs/src/families.md | 2 +- src/optimize.jl | 7 ------- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/docs/src/families.md b/docs/src/families.md index e6eaa91b..8ae48be3 100644 --- a/docs/src/families.md +++ b/docs/src/families.md @@ -5,7 +5,7 @@ The [location-scale](https://en.wikipedia.org/wiki/Location%E2%80%93scale_family) variational family is a family of probability distributions, where their sampling process can be represented as ```math z \sim q_{\lambda} 
\qquad\Leftrightarrow\qquad -z \stackrel{d}{=} z = C u + m;\quad u \sim \varphi +z \stackrel{d}{=} C u + m;\quad u \sim \varphi ``` where ``C`` is the *scale*, ``m`` is the location, and ``\varphi`` is the *base distribution*. ``m`` and ``C`` form the variational parameters ``\lambda = (m, C)`` of ``q_{\lambda}``. diff --git a/src/optimize.jl b/src/optimize.jl index 7c876b39..0f2d29e9 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -51,7 +51,6 @@ function optimize( rng ::AbstractRNG = default_rng(), show_progress::Bool = true, callback! = nothing, - #convergence = (args...) -> (false, con_state), adbackend::AbstractADType = AutoForwardDiff(), prog = ProgressMeter.Progress( n_max_iter; @@ -64,7 +63,6 @@ function optimize( λ = copy(λ₀) opt_state = Optimisers.setup(optimizer, λ) est_state = init(objective) - #con_state = init(convergence) grad_buf = DiffResults.GradientResult(λ) stats = NamedTuple[] @@ -89,11 +87,6 @@ function optimize( pm_next!(prog, stat) push!(stats, stat) - - #convergence(rng, t, restructure, λ, q, objective, stat) - #if terminate() - # break - #end end λ, map(identity, stats), opt_state end From c8b3ee3ed7ec43051631462b7674a7c1d66722d7 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sun, 13 Aug 2023 02:54:12 +0100 Subject: [PATCH 072/144] remove outdated export --- src/AdvancedVI.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index a1cf360a..1677be62 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -72,7 +72,6 @@ include("objectives/elbo/advi.jl") export ELBO, ADVI, - ADVIEnergy, ClosedFormEntropy, StickingTheLandingEntropy, FullMonteCarloEntropy From afda1a19527f4197b25a50fcae8e52cdeace660b Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sun, 13 Aug 2023 20:53:42 +0100 Subject: [PATCH 073/144] update documentation --- docs/make.jl | 9 +++-- docs/src/index.md | 16 ++++----- docs/src/{families.md => locscale.md} | 4 +-- docs/src/started.md | 51 +++++++++++++++++++++++++++ 4 files changed, 67 
insertions(+), 13 deletions(-) rename docs/src/{families.md => locscale.md} (96%) create mode 100644 docs/src/started.md diff --git a/docs/make.jl b/docs/make.jl index b9a8eb5f..ca21b5fd 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -10,9 +10,12 @@ makedocs(; sitename = "AdvancedVI.jl", modules = [AdvancedVI], format = Documenter.HTML(; prettyurls = get(ENV, "CI", nothing) == "true"), - pages = ["Home" => "index.md", - "Families" => "families.md", - "ADVI" => "advi.md"], + pages = ["AdvancedVI" => "index.md", + "Getting Started" => "started.md", + "ELBO Maximization" => [ + "Automatic Differentiation VI" => "advi.md", + "Location Scale Family" => "locscale.md", + ]], ) deploydocs(; repo="github.com/TuringLang/AdvancedVI.jl", push_preview=true) diff --git a/docs/src/index.md b/docs/src/index.md index be326921..dea6d405 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -4,11 +4,11 @@ CurrentModule = AdvancedVI # AdvancedVI -Documentation for [AdvancedVI](https://github.com/TuringLang/AdvancedVI.jl). - -```@index -``` - -```@autodocs -Modules = [AdvancedVI] -``` +## Introduction +[AdvancedVI](https://github.com/TuringLang/AdvancedVI.jl) provides implementations of variational Bayesian inference (VI) algorithms. +VI algorithms perform scalable and computationally efficient Bayesian inference at the cost of asymptotic exactness. +`AdvancedVI` is part of the [Turing](https://turinglang.org/stable/) probabilistic programming ecosystem. 
+ +## Provided Algorithms +`AdvancedVI` currently provides the following algorithm for evidence lower bound maximization: +- [Automatic Differentiation Variational Inference](@ref advi) diff --git a/docs/src/families.md b/docs/src/locscale.md similarity index 96% rename from docs/src/families.md rename to docs/src/locscale.md index 8ae48be3..a4bc2dc1 100644 --- a/docs/src/families.md +++ b/docs/src/locscale.md @@ -1,7 +1,7 @@ -# Location-Scale Variational Family +# [Location-Scale Variational Family](@id locscale) -## Description +## Introduction The [location-scale](https://en.wikipedia.org/wiki/Location%E2%80%93scale_family) variational family is a family of probability distributions, where their sampling process can be represented as ```math z \sim q_{\lambda} \qquad\Leftrightarrow\qquad diff --git a/docs/src/started.md b/docs/src/started.md new file mode 100644 index 00000000..faff6166 --- /dev/null +++ b/docs/src/started.md @@ -0,0 +1,51 @@ + +# [Getting Started with `AdvancedVI`](@id getting_started) + +## General Usage +Each VI algorithm should provide the following: +1. A variational family +2. A variational objective + +Feeding these two into `optimize` runs the inference procedure. + +```@docs +optimize +``` + +## `ADVI` Example Using `Turing` + +```julia +using Turing +using Bijectors +using Optimisers +using ForwardDiff +using ADTypes + +import AdvancedVI as AVI + +μ_y, σ_y = 1.0, 1.0 +μ_z, Σ_z = [1.0, 2.0], [1.0 0.; 0. 
2.0] + +Turing.@model function normallognormal() + y ~ LogNormal(μ_y, σ_y) + z ~ MvNormal(μ_z, Σ_z) +end +model = normallognormal() +b = Bijectors.bijector(model) +b⁻¹ = inverse(b) +prob = DynamicPPL.LogDensityFunction(model) +d = LogDensityProblems.dimension(prob) + +μ = randn(d) +L = Diagonal(ones(d)) +q = AVI.MeanFieldGaussian(μ, L) + +n_max_iter = 10^4 +q, stats = AVI.optimize( + AVI.ADVI(prob, b⁻¹, 10), + q, + n_max_iter; + adbackend = AutoForwardDiff(), + optimizer = Optimisers.Adam(1e-3) +) +``` From 0d37acea1dd96c95d7cef427be7d84fee8d95c09 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sun, 13 Aug 2023 21:12:02 +0100 Subject: [PATCH 074/144] update documentation --- docs/src/started.md | 55 ++++++++++++++++++++++++++++++++------------- 1 file changed, 39 insertions(+), 16 deletions(-) diff --git a/docs/src/started.md b/docs/src/started.md index faff6166..26c75a79 100644 --- a/docs/src/started.md +++ b/docs/src/started.md @@ -2,11 +2,13 @@ # [Getting Started with `AdvancedVI`](@id getting_started) ## General Usage -Each VI algorithm should provide the following: -1. A variational family -2. A variational objective +Each VI algorithm provides the following: +1. Variational families supported by each VI algorithm. +2. A variational objective corresponding to the VI algorithm. +Note that each variational family is subject to its own constraints. +Thus, please refer to the documentation of the variational inference algorithm of interest. -Feeding these two into `optimize` runs the inference procedure. +To use `AdvancedVI`, a user needs to select a `variational family`, `variational objective`, and feed them into `optimize`. ```@docs optimize @@ -14,14 +16,10 @@ optimize ## `ADVI` Example Using `Turing` +In this tutorial, we'll use `Turing` to define a basic `normal-log-normal` model. +ADVI with log bijectors is able to infer this model exactly. 
```julia using Turing -using Bijectors -using Optimisers -using ForwardDiff -using ADTypes - -import AdvancedVI as AVI μ_y, σ_y = 1.0, 1.0 μ_z, Σ_z = [1.0, 2.0], [1.0 0.; 0. 2.0] @@ -31,18 +31,43 @@ Turing.@model function normallognormal() z ~ MvNormal(μ_z, Σ_z) end model = normallognormal() +``` + +Since `y` follows a log-normal prior, its support is bounded to be the positive half-space ``\mathbb{R}_+``. +Thus, we will use [Bijectors](https://github.com/TuringLang/Bijectors.jl) to match the support of our target posterior and the variational approximation. +```julia +using Bijectors + b = Bijectors.bijector(model) b⁻¹ = inverse(b) -prob = DynamicPPL.LogDensityFunction(model) -d = LogDensityProblems.dimension(prob) +``` +Let's now load `AdvancedVI`. +Since ADVI relies on automatic differentiation (AD), hence the "AD" in "ADVI", we need to load an AD library, *before* loading `AdvancedVI`. +Also, the selected AD framework needs to be communicated to `AdvancedVI` using the [ADTypes](https://github.com/SciML/ADTypes.jl) interface. +Here, we will use `ForwardDiff`, which can be selected by later passing `ADTypes.AutoForwardDiff()`. +```julia +using Optimisers +using ForwardDiff +import AdvancedVI as AVI +``` +We now need to select 1. a variational objective, and 2. a variational family. +Here, we will use the [ADVI objective](@ref advi), which expects an object implementing the [`LogDensityProblems`](https://github.com/tpapp/LogDensityProblems.jl) interface, and the inverse bijector. +```julia +prob = DynamicPPL.LogDensityFunction(model) +objective = AVI.ADVI(prob, b⁻¹, 10), +``` +For the variational family, we will use the classic mean-field Gaussian family. +```julia +d = LogDensityProblems.dimension(prob) μ = randn(d) L = Diagonal(ones(d)) -q = AVI.MeanFieldGaussian(μ, L) - +q = AVI.VIMeanFieldGaussian(μ, L) +``` +It now remains to run inference! 
+``` n_max_iter = 10^4 -q, stats = AVI.optimize( - AVI.ADVI(prob, b⁻¹, 10), +q, stats = AVI.optimize( q, n_max_iter; adbackend = AutoForwardDiff(), From b8b113da2b3a64395e9daaf2bbb64e9b0b602a4e Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sun, 13 Aug 2023 21:14:17 +0100 Subject: [PATCH 075/144] update documentation --- docs/src/started.md | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/src/started.md b/docs/src/started.md index 26c75a79..355e9350 100644 --- a/docs/src/started.md +++ b/docs/src/started.md @@ -50,10 +50,11 @@ using ForwardDiff import AdvancedVI as AVI ``` We now need to select 1. a variational objective, and 2. a variational family. -Here, we will use the [ADVI objective](@ref advi), which expects an object implementing the [`LogDensityProblems`](https://github.com/tpapp/LogDensityProblems.jl) interface, and the inverse bijector. +Here, we will use the [`ADVI` objective](@ref advi), which expects an object implementing the [`LogDensityProblems`](https://github.com/tpapp/LogDensityProblems.jl) interface, and the inverse bijector. ```julia -prob = DynamicPPL.LogDensityFunction(model) -objective = AVI.ADVI(prob, b⁻¹, 10), +prob = DynamicPPL.LogDensityFunction(model) +n_montecarlo = 10 +objective = AVI.ADVI(prob, b⁻¹, n_montecarlo) ``` For the variational family, we will use the classic mean-field Gaussian family. 
```julia From b78e713eaf46d3caab540e1d818be9930bea54dc Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Wed, 16 Aug 2023 23:35:23 +0100 Subject: [PATCH 076/144] fix type error in test --- test/distributions.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/distributions.jl b/test/distributions.jl index 073fff64..9b18d020 100644 --- a/test/distributions.jl +++ b/test/distributions.jl @@ -11,7 +11,7 @@ using Distributions: _logpdf seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797) rng = Philox4x(UInt64, seed, 8) realtype = Float64 - ϵ = 1e-2 + ϵ = 1f-2 n_dims = 10 n_montecarlo = 1000_000 From a0564b56bbe86b5885c333aa7fe2ca0e48fa0b24 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Wed, 16 Aug 2023 23:35:29 +0100 Subject: [PATCH 077/144] remove default ADType argument --- Project.toml | 2 +- src/optimize.jl | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index ffc41a4b..35650ae5 100644 --- a/Project.toml +++ b/Project.toml @@ -37,7 +37,7 @@ AdvancedVIZygoteExt = "Zygote" [compat] ADTypes = "0.1" Bijectors = "0.11, 0.12, 0.13" -DiffResults = "1.0.3" +DiffResults = "1" Distributions = "0.21, 0.22, 0.23, 0.24, 0.25" DocStringExtensions = "0.8, 0.9" ForwardDiff = "0.10.25" diff --git a/src/optimize.jl b/src/optimize.jl index 0f2d29e9..93e6f754 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -31,6 +31,7 @@ Optimize the variational objective `objective` by estimating (stochastic) gradie - `n_max_iter`: Maximum number of iterations. # Keyword Arguments +- `adbackend`: Automatic differentiation backend. (Type: `<: ADTypes.AbstractADType`.) - `optimizer`: Optimizer used for inference. (Type: `<: Optimisers.AbstractRule`; Default: `Adam`.) - `rng`: Random number generator. (Type: `<: AbstractRNG`; Default: `Random.default_rng()`.) - `show_progress`: Whether to show the progress bar. (Type: `<: Bool`; Default: `true`.) 
@@ -47,11 +48,11 @@ function optimize( restructure, λ₀ ::AbstractVector{<:Real}, n_max_iter ::Int; + adbackend::AbstractADType, optimizer ::Optimisers.AbstractRule = Optimisers.Adam(), rng ::AbstractRNG = default_rng(), show_progress::Bool = true, callback! = nothing, - adbackend::AbstractADType = AutoForwardDiff(), prog = ProgressMeter.Progress( n_max_iter; desc = "Optimizing", From 3795d1e05f510887df1c2900ab9f7638797ecc87 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 17 Aug 2023 01:01:52 +0100 Subject: [PATCH 078/144] update README --- README.md | 304 +++++++++++++++--------------------------------------- 1 file changed, 81 insertions(+), 223 deletions(-) diff --git a/README.md b/README.md index 18ba63e5..e8718e7c 100644 --- a/README.md +++ b/README.md @@ -1,250 +1,108 @@ -# AdvancedVI.jl -A library for variational Bayesian inference in Julia. - -At the time of writing (05/02/2020), implementations of the variational inference (VI) interface and some algorithms are implemented in [Turing.jl](https://github.com/TuringLang/Turing.jl). The idea is to soon separate the VI functionality in Turing.jl out and into this package. - -The purpose of this package will then be to provide a common interface together with implementations of standard algorithms and utilities with the goal of ease of use and the ability for other packages, e.g. Turing.jl, to write a light wrapper around AdvancedVI.jl for integration. -As an example, in Turing.jl we support automatic differentiation variational inference (ADVI) but really the only piece of code tied into the Turing.jl is the conversion of a `Turing.Model` to a `logjoint(z)` function which computes `z ↦ log p(x, z)`, with `x` denoting the observations embedded in the `Turing.Model`. As long as this `logjoint(z)` method is compatible with some AD framework, e.g. `ForwardDiff.jl` or `Zygote.jl`, this is all we need from Turing.jl to be able to perform ADVI! 
- -## [WIP] Interface -- `vi`: the main interface to the functionality in this package - - `vi(model, alg)`: only used when `alg` has a default variational posterior which it will provide. - - `vi(model, alg, q::VariationalPosterior, θ)`: `q` represents the family of variational distributions and `θ` is the initial parameters "indexing" the starting distribution. This assumes that there exists an implementation `Variational.update(q, θ)` which returns the variational posterior corresponding to parameters `θ`. - - `vi(model, alg, getq::Function, θ)`: here `getq(θ)` is a function returning a `VariationalPosterior` corresponding to `θ`. -- `optimize!(vo, alg::VariationalInference{AD}, q::VariationalPosterior, model::Model, θ; optimizer = TruncatedADAGrad())` -- `grad!(vo, alg::VariationalInference, q, model::Model, θ, out, args...)` - - Different combinations of variational objectives (`vo`), VI methods (`alg`), and variational posteriors (`q`) might use different gradient estimators. `grad!` allows us to specify these different behaviors. +# AdvancedVI.jl +[AdvancedVI](https://github.com/TuringLang/AdvancedVI.jl) provides implementations of variational Bayesian inference (VI) algorithms. +VI algorithms perform scalable and computationally efficient Bayesian inference at the cost of asymptotic exactness. +`AdvancedVI` is part of the [Turing](https://turinglang.org/stable/) probabilistic programming ecosystem. +The purpose of this package is to provide a common accessible interface for various VI algorithms and utilities so that other packages, e.g. `Turing`, only need to write a light wrapper for integration. +For example, `Turing` combines `Turing.Model`s with `AdvancedVI.ADVI` and [`Bijectors`](https://github.com/TuringLang/Bijectors.jl) by simply converting a `Turing.Model` into a [`LogDensityProblem`](https://github.com/tpapp/LogDensityProblems.jl) and extracting a corresponding `Bijectors.bijector`. 
## Examples -### Variational Inference -A very simple generative model is the following - - μ ~ 𝒩(0, 1) - xᵢ ∼ 𝒩(μ, 1) , ∀i = 1, …, n - -where μ and xᵢ are some ℝᵈ vectors and 𝒩 denotes a d-dimensional multivariate Normal distribution. - -Given a set of `n` observations `[x₁, …, xₙ]` we're interested in finding the distribution `p(μ∣x₁, …, xₙ)` over the mean `μ`. We can obtain (an approximation to) this distribution that using AdvancedVI.jl! - -First we generate some observations and set up the problem: -```julia -julia> using Distributions - -julia> d = 2; n = 100; - -julia> observations = randn((d, n)); # 100 observations from 2D 𝒩(0, 1) - -julia> # Define generative model - # μ ~ 𝒩(0, 1) - # xᵢ ∼ 𝒩(μ, 1) , ∀i = 1, …, n - prior(μ) = logpdf(MvNormal(ones(d)), μ) -prior (generic function with 1 method) - -julia> likelihood(x, μ) = sum(logpdf(MvNormal(μ, ones(d)), x)) -likelihood (generic function with 1 method) - -julia> logπ(μ) = likelihood(observations, μ) + prior(μ) -logπ (generic function with 1 method) - -julia> logπ(randn(2)) # <= just checking that it works --311.74132761437653 -``` -Now there are mainly two different ways of specifying the approximate posterior (and its family). The first is by providing a mapping from distribution parameters to the distribution `θ ↦ q(⋅∣θ)`: -```julia -julia> using DistributionsAD, AdvancedVI - -julia> # Using a function z ↦ q(⋅∣z) - getq(θ) = TuringDiagMvNormal(θ[1:d], exp.(θ[d + 1:4])) -getq (generic function with 1 method) -``` -Then we make the choice of algorithm, a subtype of `VariationalInference`, -```julia -julia> # Perform VI - advi = ADVI(10, 10_000) -ADVI{AdvancedVI.ForwardDiffAD{40}}(10, 10000) -``` -And finally we can perform VI! 
The usual inferface is to call `vi` which behind the scenes takes care of the optimization and returns the resulting variational posterior: -```julia -julia> q = vi(logπ, advi, getq, randn(4)) -[ADVI] Optimizing...100% Time: 0:00:01 -TuringDiagMvNormal{Array{Float64,1},Array{Float64,1}}(m=[0.16282745378074515, 0.15789310089462574], σ=[0.09519377533754399, 0.09273176907111745]) -``` -Let's have a look at the resulting ELBO: -```julia -julia> AdvancedVI.elbo(advi, q, logπ, 1000) --287.7866366886285 -``` -Unfortunately, the *final* value of the ELBO is not always a very good diagnostic, though the ELBO is an important metric to keep an eye on during training since an *increase* in the ELBO means we're going in the right direction. Luckily, this is such a simple problem that we can indeed obtain a closed form solution! Because we're lazy (at least I am), we'll let [ConjugatePriors.jl](https://github.com/JuliaStats/ConjugatePriors.jl) do this for us: -```julia -julia> # True posterior - using ConjugatePriors -julia> pri = MvNormal(zeros(2), ones(2)); +`AdvancedVI` basically expects a `LogDensityProblem`. +For example, for the normal-log-normal model: +$$ +\begin{aligned} +x &\sim \mathsf{log\text{-}normal}\left(\mu_x, \sigma_x^2\right) \\ +y &\sim \mathsf{normal}\left(\mu_y, \sigma_y^2\right) +\end{aligned} +$$ -julia> true_posterior = posterior((pri, pri.Σ), MvNormal, observations) -DiagNormal( -dim: 2 -μ: [0.1746546592601148, 0.16457110079543008] -Σ: [0.009900990099009901 0.0; 0.0 0.009900990099009901] -) +A `LogDensityProblem` can be implemented as ``` -Comparing to our variational approximation, this looks pretty good! Worth noting that in this particular case the variational posterior seems to overestimate the variance. 
+using LogDensityProblems -To conclude, let's make a somewhat pretty picture: -```julia -julia> using Plots - -julia> p_samples = rand(true_posterior, 10_000); q_samples = rand(q, 10_000); - -julia> p1 = histogram(p_samples[1, :], label="p"); histogram!(q_samples[1, :], alpha=0.7, label="q") - -julia> title!(raw"$\mu_1$") +struct NormalLogNormal{MX,SX,MY,SY} + μ_x::MX + σ_x::SX + μ_y::MY + Σ_y::SY +end -julia> p2 = histogram(p_samples[2, :], label="p"); histogram!(q_samples[2, :], alpha=0.7, label="q") +function LogDensityProblems.logdensity(model::NormalLogNormal, θ) + @unpack μ_x, σ_x, μ_y, Σ_y = model + logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end]) +end -julia> title!(raw"$\mu_2$") +function LogDensityProblems.dimension(model::NormalLogNormal) + length(model.μ_y) + 1 +end -julia> plot(p1, p2) +function LogDensityProblems.capabilities(::Type{<:NormalLogNormal}) + LogDensityProblems.LogDensityOrder{0}() +end ``` -![Histogram](hist.png?raw=true) - -### Simple example: using Advanced.jl to directly minimize the KL-divergence between two distributions `p(z)` and `q(z)` -In VI we aim to approximate the true posterior `p(z ∣ x)` by some approximate variational posterior `q(z)` by maximizing the ELBO: - - ELBO(q) = 𝔼_q[log p(x, z) - log q(z)] - -Observe that we can express the ELBO as the negative KL-divergence between `p(x, ⋅)` and `q(⋅)`: - - ELBO(q) = - 𝔼_q[log (q(z) / p(x, z))] - = - KL(q(⋅) || p(x, ⋅)) - -So if we apply VI to something that isn't an actual posterior, i.e. there's no data involved and we write `p(z ∣ x) = p(z)`, we're really just minimizing the KL-divergence between the distributions. - -Therefore, we can try out `AdvancedVI.jl` real quick by applying using the interface to minimize the KL-divergence between two distributions: +Since the support of `x` is constrained to be $$\mathbb{R}_+$$, and inference is best done in the unconstrained space $$\mathbb{R}_+$$, we need to use a *bijector* to match support. 
+This corresponds to the automatic differentiation VI (ADVI; Kucukelbir *et al.*, 2015). ```julia -julia> using Distributions, DistributionsAD, AdvancedVI - -julia> # Target distribution - p = MvNormal(ones(2)) -ZeroMeanDiagNormal( -dim: 2 -μ: [0.0, 0.0] -Σ: [1.0 0.0; 0.0 1.0] -) +using Bijectors -julia> logπ(z) = logpdf(p, z) -logπ (generic function with 1 method) - -julia> # Make a choice of VI algorithm - advi = ADVI(10, 1000) -ADVI{AdvancedVI.ForwardDiffAD{40}}(10, 1000) -``` -Now there are two different ways of specifying the approximate posterior (and its family); the first is by providing a mapping from parameters to distribution `θ ↦ q(⋅∣θ)`: -```julia -julia> # Using a function z ↦ q(⋅∣z) - getq(θ) = TuringDiagMvNormal(θ[1:2], exp.(θ[3:4])) -getq (generic function with 1 method) - -julia> # Perform VI - q = vi(logπ, advi, getq, randn(4)) -┌ Info: [ADVI] Should only be seen once: optimizer created for θ -└ objectid(θ) = 0x5ddb564423896704 -[ADVI] Optimizing...100% Time: 0:00:01 -TuringDiagMvNormal{Array{Float64,1},Array{Float64,1}}(m=[-0.012691337868985757, -0.0004442434543332919], σ=[1.0334797673569802, 0.9957355128767893]) -``` -Or we can check the ELBO (which in this case since, as mentioned, doesn't involve data, is the negative KL-divergence): -```julia -julia> AdvancedVI.elbo(advi, q, logπ, 1000) # empirical estimate -0.08031049170093245 +function Bijectors.bijector(model::NormalLogNormal) + @unpack μ_x, σ_x, μ_y, Σ_y = model + Bijectors.Stacked( + Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]), + [1:1, 2:1+length(μ_y)]) +end ``` -It's worth noting that the actual value of the ELBO doesn't really tell us too much about the quality of fit. In this particular case, because we're *directly* minimizing the KL-divergence, we can only say something useful if we reach 0, in which case we have obtained the true distribution. 
-Let's just quickly check the mean-squared error between the `log p(z)` and `log q(z)` for a random set of samples from the target `p`: -```julia -julia> zs = rand(p, 100); +A simpler approach is to use `Turing`, where a `Turing.Model` can be automatically be converted into a `LogDensityProblem` and a corresponding `bijector` is automatically generated. -julia> mean(abs2, logpdf(q, zs) - logpdf(p, zs)) -0.0014889109427524852 +Let us instantiate a random normal-log-normal model. +```julia +using PDMats + +n_dims = 10 +μ_x = randn() +σ_x = exp.(randn()) +μ_y = randn(n_dims) +σ_y = exp.(randn(n_dims)) +model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2)) ``` -That doesn't look too bad! - -## Implementing your own training loop -Sometimes it might be convenient to roll your own training loop rather than using `vi(...)`. Here's some psuedo-code for how one would do that when used together with Turing.jl: +ADVI can be used as follows: ```julia -using Turing, AdvancedVI, DiffResults -using Turing: Variational - -using ProgressMeter - -# Assuming you have an instance of a Turing model (`model`) - -# 1. Create log-joint needed for ELBO evaluation -logπ = Variational.make_logjoint(model) - -# 2. Define objective -variational_objective = Variational.ELBO() - -# 3. Optimizer -optimizer = Variational.DecayedADAGrad() - -# 4. VI-algorithm -alg = ADVI(10, 1000) - -# 5. Variational distribution -function getq(θ) - # ... -end - -# 6. [OPTIONAL] Implement convergence criterion -function hasconverged(args...) - # ... -end - -# 7. [OPTIONAL] Implement a callback for tracking stats -function callback(args...) - # ... -end - -# 8. Train -converged = false -step = 1 - -prog = ProgressMeter.Progress(num_steps, 1) - -diff_results = DiffResults.GradientResult(θ_init) - -while (step ≤ num_steps) && !converged - # 1. Compute gradient and objective value; results are stored in `diff_results` - AdvancedVI.grad!(variational_objective, alg, getq, model, diff_results) - - # 2. 
Extract gradient from `diff_result` - ∇ = DiffResults.gradient(diff_result) - - # 3. Apply optimizer, e.g. multiplying by step-size - Δ = apply!(optimizer, θ, ∇) - - # 4. Update parameters - @. θ = θ - Δ - - # 5. Do whatever analysis you want - callback(args...) - - # 6. Update - converged = hasconverged(...) # or something user-defined - step += 1 +using LinearAlgebra +using Optimisers +using ADTypes, ForwardDiff +import AdvancedVI as AVI + +b = Bijectors.bijector(model) +b⁻¹ = inverse(b) + +# ADVI objective +objective = AVI.ADVI(model, 10; b=b⁻¹) + +# Mean-field Gaussian variational family +d = LogDensityProblems.dimension(model) +μ = randn(d) +L = Diagonal(ones(d)) +q = AVI.VIMeanFieldGaussian(μ, L) + +# Run inference +n_max_iter = 10^4 +q, stats, _ = AVI.optimize( + objective, + q, + n_max_iter; + adbackend = ADTypes.AutoForwardDiff(), + optimizer = Optimisers.Adam(1e-3) +) - ProgressMeter.next!(prog) -end +# Evaluate final ELBO with 10^3 Monte Carlo samples +objective(q; n_samples=10^3) ``` ## References -- Jordan, Michael I., Zoubin Ghahramani, Tommi S. Jaakkola, and Lawrence K. Saul. "An introduction to variational methods for graphical models." Machine learning 37, no. 2 (1999): 183-233. -- Blei, David M., Alp Kucukelbir, and Jon D. McAuliffe. "Variational inference: A review for statisticians." Journal of the American statistical Association 112, no. 518 (2017): 859-877. - Kucukelbir, Alp, Rajesh Ranganath, Andrew Gelman, and David Blei. "Automatic variational inference in Stan." In Advances in Neural Information Processing Systems, pp. 568-576. 2015. -- Salimans, Tim, and David A. Knowles. "Fixed-form variational posterior approximation through stochastic linear regression." Bayesian Analysis 8, no. 4 (2013): 837-882. -- Beal, Matthew James. Variational algorithms for approximate Bayesian inference. 2003. 
From 28a35bcd0ce6bd4489915ae1cf37801db211b2ec Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 17 Aug 2023 01:02:04 +0100 Subject: [PATCH 079/144] update make getting started example actually run Julia --- docs/Project.toml | 13 ++++- docs/src/started.md | 115 +++++++++++++++++++++++++++++++++----------- 2 files changed, 98 insertions(+), 30 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index c625d07f..182edd3e 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,7 +1,18 @@ [deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" AdvancedVI = "b5ca4192-6429-45e5-a2d9-87aec30a685c" +Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" +PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" +Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" [compat] -Documenter = "0.26, 0.27" \ No newline at end of file +ADTypes = "0.1.6" +Bijectors = "0.13.6" +Documenter = "0.26, 0.27" +LogDensityProblems = "2.1.1" diff --git a/docs/src/started.md b/docs/src/started.md index 355e9350..fec60f1a 100644 --- a/docs/src/started.md +++ b/docs/src/started.md @@ -14,62 +14,119 @@ To use `AdvancedVI`, a user needs to select a `variational family`, `variational optimize ``` -## `ADVI` Example Using `Turing` +## `ADVI` Example +In this tutorial, we will work with a basic `normal-log-normal` model. +```math +\begin{aligned} +x &\sim \mathsf{log\text{-}normal}\left(\mu_x, \sigma_x^2\right) \\ +y &\sim \mathsf{normal}\left(\mu_y, \sigma_y^2\right) +\end{aligned} +``` +ADVI with `Bijectors.Exp` bijectors is able to infer this model exactly. -In this tutorial, we'll use `Turing` to define a basic `normal-log-normal` model. 
-ADVI with log bijectors is able to infer this model exactly. -```julia -using Turing +Using the `LogDensityProblems` interface, we the model can be defined as follows: +```@example advi +using LogDensityProblems +using SimpleUnPack -μ_y, σ_y = 1.0, 1.0 -μ_z, Σ_z = [1.0, 2.0], [1.0 0.; 0. 2.0] +struct NormalLogNormal{MX,SX,MY,SY} + μ_x::MX + σ_x::SX + μ_y::MY + Σ_y::SY +end -Turing.@model function normallognormal() - y ~ LogNormal(μ_y, σ_y) - z ~ MvNormal(μ_z, Σ_z) +function LogDensityProblems.logdensity(model::NormalLogNormal, θ) + @unpack μ_x, σ_x, μ_y, Σ_y = model + logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end]) end -model = normallognormal() + +function LogDensityProblems.dimension(model::NormalLogNormal) + length(model.μ_y) + 1 +end + +function LogDensityProblems.capabilities(::Type{<:NormalLogNormal}) + LogDensityProblems.LogDensityOrder{0}() +end +``` +Let's now instantiate the model +```@example advi +using PDMats + +n_dims = 10 +μ_x = randn() +σ_x = exp.(randn()) +μ_y = randn(n_dims) +σ_y = exp.(randn(n_dims)) +model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2)); ``` Since the `y` follows a log-normal prior, its support is bounded to be the positive half-space ``\mathbb{R}_+``. Thus, we will use [Bijectors](https://github.com/TuringLang/Bijectors.jl) to match the support of our target posterior and the variational approximation. -```julia +```@example advi using Bijectors -b = Bijectors.bijector(model) -b⁻¹ = inverse(b) +function Bijectors.bijector(model::NormalLogNormal) + @unpack μ_x, σ_x, μ_y, Σ_y = model + Bijectors.Stacked( + Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]), + [1:1, 2:1+length(μ_y)]) +end + +b = Bijectors.bijector(model); +b⁻¹ = inverse(b) ``` Let's now load `AdvancedVI`. Since ADVI relies on automatic differentiation (AD), hence the "AD" in "ADVI", we need to load an AD library, *before* loading `AdvancedVI`. 
Also, the selected AD framework needs to be communicated to `AdvancedVI` using the [ADTypes](https://github.com/SciML/ADTypes.jl) interface. Here, we will use `ForwardDiff`, which can be selected by later passing `ADTypes.AutoForwardDiff()`. -```julia +```@example advi using Optimisers -using ForwardDiff +using ADTypes, ForwardDiff import AdvancedVI as AVI ``` We now need to select 1. a variational objective, and 2. a variational family. Here, we will use the [`ADVI` objective](@ref advi), which expects an object implementing the [`LogDensityProblems`](https://github.com/tpapp/LogDensityProblems.jl) interface, and the inverse bijector. -```julia -prob = DynamicPPL.LogDensityFunction(model)] -n_montecaro = 10 -objective = AVI.ADVI(prob, b⁻¹, n_montecaro), +```@example advi +n_montecaro = 10; +objective = AVI.ADVI(model, n_montecaro; b = b⁻¹) ``` For the variational family, we will use the classic mean-field Gaussian family. -```julia -d = LogDensityProblems.dimension(prob) -μ = randn(d) -L = Diagonal(ones(d)) +```@example advi +using LinearAlgebra + +d = LogDensityProblems.dimension(model); +μ = randn(d); +L = Diagonal(ones(d)); q = AVI.VIMeanFieldGaussian(μ, L) ``` -It now remains to run inverence! -``` -n_max_iter = 10^4 -q, stats = AVI.optimize( +Passing `objective` and the initial variational approximation `q` to `optimize` performs inference. +```@example advi +n_max_iter = 10^4 +q, stats, _ = AVI.optimize( + objective, q, n_max_iter; adbackend = AutoForwardDiff(), optimizer = Optimisers.Adam(1e-3) -) +); +``` + +The selected inference procedure stores per-iteration statistics into `stats`. +For instance, the ELBO can be ploted as follows: +```@example advi +using Plots + +t = [stat.iteration for stat ∈ stats] +y = [stat.elbo for stat ∈ stats] +plot(t[1:100:end], y[1:100:end]) +savefig("advi_example_elbo.svg"); nothing +``` +![](advi_example_elbo.svg) +Further information can be gathered by defining your own `callback!`. 
+ +The final ELBO can be estimated by calling the objective directly with a different number of Monte Carlo samples as follows: +```@example advi +ELBO = objective(q; n_samples=10^4) ``` From 620b38e7d345c60d59c08174144f1349618ff60c Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 17 Aug 2023 01:02:16 +0100 Subject: [PATCH 080/144] fix remove Float32 tests for inference tests --- ext/AdvancedVIForwardDiffExt.jl | 2 +- test/advi_locscale.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/AdvancedVIForwardDiffExt.jl b/ext/AdvancedVIForwardDiffExt.jl index e6b03af2..5949bdf8 100644 --- a/ext/AdvancedVIForwardDiffExt.jl +++ b/ext/AdvancedVIForwardDiffExt.jl @@ -11,8 +11,8 @@ else using ..AdvancedVI: ADTypes, DiffResults end -# extract chunk size from AutoForwardDiff getchunksize(::ADTypes.AutoForwardDiff{chunksize}) where {chunksize} = chunksize + function AdvancedVI.value_and_gradient!( ad::ADTypes.AutoForwardDiff, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult ) where {T<:Real} diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl index d4ef7aec..e4c81402 100644 --- a/test/advi_locscale.jl +++ b/test/advi_locscale.jl @@ -25,7 +25,7 @@ include("models/utils.jl") @testset "advi" begin @testset "locscale" begin @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for - realtype ∈ [Float32, Float64], + realtype ∈ [Float64], # Currently only tested against Float64 (modelname, modelconstr) ∈ Dict( :NormalLogNormalMeanField => normallognormal_meanfield, :NormalLogNormalFullRank => normallognormal_fullrank, From fa533981d6c3208e008d04f35a18ec08728ca608 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 17 Aug 2023 01:54:13 +0100 Subject: [PATCH 081/144] update version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 35650ae5..2092b0cb 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "AdvancedVI" uuid = 
"b5ca4192-6429-45e5-a2d9-87aec30a685c" -version = "0.2.4" +version = "0.3.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From e909f4106e919e2d834a4f73eac3ca929bd5b9dd Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 17 Aug 2023 20:04:34 +0100 Subject: [PATCH 082/144] add documentation publishing url --- docs/make.jl | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/docs/make.jl b/docs/make.jl index ca21b5fd..5d371608 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -7,15 +7,16 @@ DocMeta.setdocmeta!( ) makedocs(; - sitename = "AdvancedVI.jl", modules = [AdvancedVI], + sitename = "AdvancedVI.jl", + repo = "https://github.com/TuringLang/AdvancedVI.jl/blob/{commit}{path}#{line}", format = Documenter.HTML(; prettyurls = get(ENV, "CI", nothing) == "true"), - pages = ["AdvancedVI" => "index.md", - "Getting Started" => "started.md", - "ELBO Maximization" => [ - "Automatic Differentiation VI" => "advi.md", - "Location Scale Family" => "locscale.md", - ]], + pages = ["AdvancedVI" => "index.md", + "Getting Started" => "started.md", + "ELBO Maximization" => [ + "Automatic Differentiation VI" => "advi.md", + "Location Scale Family" => "locscale.md", + ]], ) deploydocs(; repo="github.com/TuringLang/AdvancedVI.jl", push_preview=true) From 43f5b751abb963533cbb6835ca6c8315a53a41d2 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 17 Aug 2023 20:17:04 +0100 Subject: [PATCH 083/144] fix wrong uuid for ForwardDiff --- src/AdvancedVI.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 1677be62..c45d4997 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -106,7 +106,7 @@ function __init__() @require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" begin include("../ext/AdvancedVIEnzymeExt.jl") end - @require ForwardDiff = "e88e6eb3-aa80-5325-afca-941959d7151f" begin + @require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" begin 
include("../ext/AdvancedVIForwardDiffExt.jl") end @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin From 468d5ca3aa94f7c83287633beba23aa5d174ca88 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Thu, 17 Aug 2023 21:44:15 +0100 Subject: [PATCH 084/144] Update CI.yml --- .github/workflows/CI.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 158da963..26f6876f 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -20,7 +20,7 @@ jobs: - windows-latest arch: - x64 - - x86 + # - x86 # Uncomment after https://github.com/JuliaTesting/ReTest.jl/pull/52 is merged exclude: - os: macOS-latest arch: x86 From c07a5118a237fd5eb3a478a88fdcefe06673b366 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 17 Aug 2023 21:49:26 +0100 Subject: [PATCH 085/144] refactor use `sum` and `mean` instead of abusing `mapreduce` --- src/distributions/location_scale.jl | 4 ++-- src/objectives/elbo/advi.jl | 5 ++--- src/objectives/elbo/entropy.jl | 5 ++--- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl index e901e8de..3113c679 100644 --- a/src/distributions/location_scale.jl +++ b/src/distributions/location_scale.jl @@ -42,12 +42,12 @@ end function logpdf(q::VILocationScale, z::AbstractVector{<:Real}) @unpack location, scale, dist = q - mapreduce(zᵢ -> logpdf(dist, zᵢ), +, scale \ (z - location)) - first(logabsdet(scale)) + sum(zᵢ -> logpdf(dist, zᵢ), scale \ (z - location)) - first(logabsdet(scale)) end function _logpdf(q::VILocationScale, z::AbstractVector{<:Real}) @unpack location, scale, dist = q - mapreduce(zᵢ -> logpdf(dist, zᵢ), +, scale \ (z - location)) - first(logabsdet(scale)) + sum(zᵢ -> logpdf(dist, zᵢ), scale \ (z - location)) - first(logabsdet(scale)) end function rand(q::VILocationScale) diff --git a/src/objectives/elbo/advi.jl 
b/src/objectives/elbo/advi.jl index 8bc14bc9..67af4375 100644 --- a/src/objectives/elbo/advi.jl +++ b/src/objectives/elbo/advi.jl @@ -55,10 +55,9 @@ function (advi::ADVI)( q_η::ContinuousMultivariateDistribution, ηs ::AbstractMatrix ) - n_samples = size(ηs, 2) - 𝔼ℓ = mapreduce(+, eachcol(ηs)) do ηᵢ + 𝔼ℓ = mean(eachcol(ηs)) do ηᵢ zᵢ, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(advi.b, ηᵢ) - (advi.ℓπ(zᵢ) + logdetjacᵢ) / n_samples + (advi.ℓπ(zᵢ) + logdetjacᵢ) end ℍ = advi.entropy(q_η, ηs) 𝔼ℓ + ℍ diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl index 0edc47f4..694eacef 100644 --- a/src/objectives/elbo/entropy.jl +++ b/src/objectives/elbo/entropy.jl @@ -23,9 +23,8 @@ The "sticking the landing" entropy estimator. struct StickingTheLandingEntropy <: MonteCarloEntropy end function (::MonteCarloEntropy)(q, ηs::AbstractMatrix) - n_samples = size(ηs, 2) - mapreduce(+, eachcol(ηs)) do ηᵢ - -logpdf(q, ηᵢ) / n_samples + mean(eachcol(ηs)) do ηᵢ + -logpdf(q, ηᵢ) end end From 13a8a445af64690b61137f6791f4f11eb6130a2b Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 17 Aug 2023 22:14:42 +0100 Subject: [PATCH 086/144] remove tests for `FullMonteCarlo` --- test/advi_locscale.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl index e4c81402..962d3169 100644 --- a/test/advi_locscale.jl +++ b/test/advi_locscale.jl @@ -35,7 +35,6 @@ include("models/utils.jl") (objname, objective) ∈ Dict( :ADVIClosedFormEntropy => (model, b, M) -> ADVI(model, M; b), :ADVIStickingTheLanding => (model, b, M) -> ADVI(model, M; b, entropy = StickingTheLandingEntropy()), - :ADVIFullMonteCarlo => (model, b, M) -> ADVI(model, M; b, entropy = FullMonteCarloEntropy()), ), (adbackname, adbackend) ∈ Dict( :ForwarDiff => AutoForwardDiff(), From aadf8d397aad300b6e5d502b8a90bd0f2724d778 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Fri, 18 Aug 2023 01:31:58 +0100 Subject: [PATCH 087/144] add tests for the `optimize` interface --- 
test/advi_locscale.jl | 4 +-- test/optimize.jl | 84 +++++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 2 ++ 3 files changed, 88 insertions(+), 2 deletions(-) create mode 100644 test/optimize.jl diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl index 962d3169..bf51199f 100644 --- a/test/advi_locscale.jl +++ b/test/advi_locscale.jl @@ -38,8 +38,8 @@ include("models/utils.jl") ), (adbackname, adbackend) ∈ Dict( :ForwarDiff => AutoForwardDiff(), - # :ReverseDiff => AutoReverseDiff(), - # :Zygote => AutoZygote(), + :ReverseDiff => AutoReverseDiff(), + :Zygote => AutoZygote(), # :Enzyme => AutoEnzyme(), ) diff --git a/test/optimize.jl b/test/optimize.jl new file mode 100644 index 00000000..3ece467f --- /dev/null +++ b/test/optimize.jl @@ -0,0 +1,84 @@ + +using ReTest +using Bijectors +using LogDensityProblems +using Optimisers +using Distributions +using PDMats +using LinearAlgebra +using SimpleUnPack: @unpack + +struct TestModel{M,L,S} + model::M + μ_true::L + L_true::S + n_dims::Int + is_meanfield::Bool +end + +include("models/normallognormal.jl") +include("models/utils.jl") + +@testset "optimize" begin + seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797) + rng = Philox4x(UInt64, seed, 8) + + T = 1000 + modelstats = normallognormal_meanfield(Float64; rng) + + @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats + + # Global Test Configurations + b⁻¹ = Bijectors.bijector(model) |> inverse + μ₀ = zeros(Float64, n_dims) + L₀ = ones(Float64, n_dims) |> Diagonal + q₀ = VIMeanFieldGaussian(μ₀, L₀) + obj = ADVI(model, 10; b=b⁻¹) + + adbackend = AutoForwardDiff() + optimizer = Optimisers.Adam(1e-2) + + rng = Philox4x(UInt64, seed, 8) + q_ref, stats_ref, _ = optimize( + obj, q₀, T; + optimizer, + show_progress = false, + rng, + adbackend, + ) + λ_ref, _ = Optimisers.destructure(q_ref) + + @testset "restructure" begin + λ₀, re = Optimisers.destructure(q₀) + + rng = Philox4x(UInt64, seed, 8) + λ, stats, _ = optimize( + obj, re, λ₀, T; + optimizer, + 
show_progress = false, + rng, + adbackend, + ) + @test λ == λ_ref + @test stats == stats_ref + end + + @testset "callback" begin + rng = Philox4x(UInt64, seed, 8) + test_values = rand(rng, T) + + callback!(; stat, est_state, restructure, λ) = begin + (test_value = test_values[stat.iteration],) + end + + rng = Philox4x(UInt64, seed, 8) + _, stats, _ = optimize( + obj, q₀, T; + show_progress = false, + rng, + adbackend, + callback! + ) + @test [stat.test_value for stat ∈ stats] == test_values + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 68225fd9..6bd3bc49 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,11 +8,13 @@ using Random123 using Statistics using Distributions using LinearAlgebra + using AdvancedVI include("ad.jl") include("distributions.jl") include("advi_locscale.jl") +include("optimize.jl") @main function runtests(patterns...; dry::Bool = false) retest(patterns...; dry = dry, verbose = Inf) From 8c4e13db72524ad31bf6306219436d3b78320237 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Fri, 18 Aug 2023 01:33:05 +0100 Subject: [PATCH 088/144] fix turn off Zygote tests for now --- test/advi_locscale.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl index bf51199f..e8b4be03 100644 --- a/test/advi_locscale.jl +++ b/test/advi_locscale.jl @@ -39,7 +39,7 @@ include("models/utils.jl") (adbackname, adbackend) ∈ Dict( :ForwarDiff => AutoForwardDiff(), :ReverseDiff => AutoReverseDiff(), - :Zygote => AutoZygote(), + # :Zygote => AutoZygote(), # :Enzyme => AutoEnzyme(), ) From 0b708e6297d781722a582058d42f7e0917cf49bd Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Fri, 18 Aug 2023 03:09:11 +0100 Subject: [PATCH 089/144] remove unused function --- src/objectives/elbo/entropy.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl index 694eacef..022ed4f6 100644 --- a/src/objectives/elbo/entropy.jl +++ 
b/src/objectives/elbo/entropy.jl @@ -5,8 +5,6 @@ function (::ClosedFormEntropy)(q, ::AbstractMatrix) entropy(q) end -skip_entropy_gradient(::ClosedFormEntropy) = false - abstract type MonteCarloEntropy <: AbstractEntropyEstimator end struct FullMonteCarloEntropy <: MonteCarloEntropy end From be61acd46d457206cbd07386377958d5afb178e3 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Fri, 18 Aug 2023 03:51:34 +0100 Subject: [PATCH 090/144] refactor change bijector field name, simplify STL estimator --- Project.toml | 2 ++ src/AdvancedVI.jl | 4 +-- src/objectives/elbo/advi.jl | 46 +++++++--------------------------- src/objectives/elbo/entropy.jl | 15 ++++++----- test/advi_locscale.jl | 4 +-- test/optimize.jl | 2 +- 6 files changed, 25 insertions(+), 48 deletions(-) diff --git a/Project.toml b/Project.toml index 2092b0cb..e099308a 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "0.3.0" ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" @@ -37,6 +38,7 @@ AdvancedVIZygoteExt = "Zygote" [compat] ADTypes = "0.1" Bijectors = "0.11, 0.12, 0.13" +ChainRules = "1.53.0" DiffResults = "1" Distributions = "0.21, 0.22, 0.23, 0.24, 0.25" DocStringExtensions = "0.8, 0.9" diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index c45d4997..cca220f1 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -23,7 +23,7 @@ using LogDensityProblems using ADTypes, DiffResults using ADTypes: AbstractADType - +using ChainRules: @ignore_derivatives using FillArrays using PDMats @@ -74,7 +74,7 @@ export ADVI, ClosedFormEntropy, StickingTheLandingEntropy, - FullMonteCarloEntropy + MonteCarloEntropy # Variational Families diff --git a/src/objectives/elbo/advi.jl 
b/src/objectives/elbo/advi.jl index 67af4375..788449d1 100644 --- a/src/objectives/elbo/advi.jl +++ b/src/objectives/elbo/advi.jl @@ -11,7 +11,7 @@ Automatic differentiation variational inference (ADVI; Kucukelbir *et al.* 2017) # Keyword Arguments - `entropy`: The estimator for the entropy term. (Type `<: AbstractEntropyEstimator`; Default: ClosedFormEntropy()) - `cv`: A control variate. -- `b`: A bijector mapping the support of the base distribution to that of `prob`. (Default: `Bijectors.identity`.) +- `invbij`: A bijective mapping from the support of the base distribution to that of `prob`. (Default: `Bijectors.identity`.) # Requirements - ``q_{\\lambda}`` implements `rand`. @@ -23,7 +23,7 @@ struct ADVI{Tlogπ, B, EntropyEst <: AbstractEntropyEstimator, ControlVar <: Union{<: AbstractControlVariate, Nothing}} <: AbstractVariationalObjective ℓπ::Tlogπ - b::B + invbij::B entropy::EntropyEst cv::ControlVar n_samples::Int @@ -31,7 +31,7 @@ struct ADVI{Tlogπ, B, function ADVI(prob, n_samples::Int; entropy::AbstractEntropyEstimator = ClosedFormEntropy(), cv::Union{<:AbstractControlVariate, Nothing} = nothing, - b = Bijectors.identity) + invbij = Bijectors.identity) cap = LogDensityProblems.capabilities(prob) if cap === nothing throw( @@ -41,7 +41,7 @@ struct ADVI{Tlogπ, B, ) end ℓπ = Base.Fix1(LogDensityProblems.logdensity, prob) - new{typeof(ℓπ), typeof(b), typeof(entropy), typeof(cv)}(ℓπ, b, entropy, cv, n_samples) + new{typeof(ℓπ), typeof(invbij), typeof(entropy), typeof(cv)}(ℓπ, invbij, entropy, cv, n_samples) end end @@ -56,7 +56,7 @@ function (advi::ADVI)( ηs ::AbstractMatrix ) 𝔼ℓ = mean(eachcol(ηs)) do ηᵢ - zᵢ, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(advi.b, ηᵢ) + zᵢ, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(advi.invbij, ηᵢ) (advi.ℓπ(zᵢ) + logdetjacᵢ) end ℍ = advi.entropy(q_η, ηs) @@ -86,50 +86,22 @@ function (advi::ADVI)( advi(rng, q_η, ηs) end -function estimate_advi_gradient_maybe_stl!( - rng::AbstractRNG, - adbackend::AbstractADType, -
advi::ADVI{P, B, StickingTheLandingEntropy, CV}, - λ::Vector{<:Real}, - restructure, - out::DiffResults.MutableDiffResult -) where {P, B, CV} - q_η_stop = restructure(λ) - f(λ′) = begin - q_η = restructure(λ′) - ηs = rand(rng, q_η, advi.n_samples) - -advi(rng, q_η_stop, ηs) - end - value_and_gradient!(adbackend, f, λ, out) -end - -function estimate_advi_gradient_maybe_stl!( +function estimate_gradient( rng::AbstractRNG, adbackend::AbstractADType, - advi::ADVI{P, B, <:Union{ClosedFormEntropy, FullMonteCarloEntropy}, CV}, + advi::ADVI, + est_state, λ::Vector{<:Real}, restructure, out::DiffResults.MutableDiffResult -) where {P, B, CV} +) f(λ′) = begin q_η = restructure(λ′) ηs = rand(rng, q_η, advi.n_samples) -advi(rng, q_η, ηs) end value_and_gradient!(adbackend, f, λ, out) -end -function estimate_gradient( - rng::AbstractRNG, - adbackend::AbstractADType, - advi::ADVI, - est_state, - λ::Vector{<:Real}, - restructure, - out::DiffResults.MutableDiffResult -) - estimate_advi_gradient_maybe_stl!( - rng, adbackend, advi, λ, restructure, out) nelbo = DiffResults.value(out) stat = (elbo=-nelbo,) diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl index 022ed4f6..97ccda29 100644 --- a/src/objectives/elbo/entropy.jl +++ b/src/objectives/elbo/entropy.jl @@ -5,9 +5,13 @@ function (::ClosedFormEntropy)(q, ::AbstractMatrix) entropy(q) end -abstract type MonteCarloEntropy <: AbstractEntropyEstimator end +struct MonteCarloEntropy <: AbstractEntropyEstimator end -struct FullMonteCarloEntropy <: MonteCarloEntropy end +function (::MonteCarloEntropy)(q, ηs::AbstractMatrix) + mean(eachcol(ηs)) do ηᵢ + -logpdf(q, ηᵢ) + end +end """ StickingTheLandingEntropy() @@ -18,11 +22,10 @@ The "sticking the landing" entropy estimator. - `q` implements `logpdf`. - `logpdf(q, η)` must be differentiable by the selected AD framework. 
""" -struct StickingTheLandingEntropy <: MonteCarloEntropy end +struct StickingTheLandingEntropy <: AbstractEntropyEstimator end -function (::MonteCarloEntropy)(q, ηs::AbstractMatrix) - mean(eachcol(ηs)) do ηᵢ +function (::StickingTheLandingEntropy)(q, ηs::AbstractMatrix) + @ignore_derivatives mean(eachcol(ηs)) do ηᵢ -logpdf(q, ηᵢ) end end - diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl index e8b4be03..71cf22d5 100644 --- a/test/advi_locscale.jl +++ b/test/advi_locscale.jl @@ -33,8 +33,8 @@ include("models/utils.jl") :NormalFullRank => normal_fullrank, ), (objname, objective) ∈ Dict( - :ADVIClosedFormEntropy => (model, b, M) -> ADVI(model, M; b), - :ADVIStickingTheLanding => (model, b, M) -> ADVI(model, M; b, entropy = StickingTheLandingEntropy()), + :ADVIClosedFormEntropy => (model, b⁻¹, M) -> ADVI(model, M; invbij = b⁻¹), + :ADVIStickingTheLanding => (model, b⁻¹, M) -> ADVI(model, M; invbij = b⁻¹, entropy = StickingTheLandingEntropy()), ), (adbackname, adbackend) ∈ Dict( :ForwarDiff => AutoForwardDiff(), diff --git a/test/optimize.jl b/test/optimize.jl index 3ece467f..d514d236 100644 --- a/test/optimize.jl +++ b/test/optimize.jl @@ -33,7 +33,7 @@ include("models/utils.jl") μ₀ = zeros(Float64, n_dims) L₀ = ones(Float64, n_dims) |> Diagonal q₀ = VIMeanFieldGaussian(μ₀, L₀) - obj = ADVI(model, 10; b=b⁻¹) + obj = ADVI(model, 10; invbij=b⁻¹) adbackend = AutoForwardDiff() optimizer = Optimisers.Adam(1e-2) From fb519a501585fd279a62bce331ea81b19627ba06 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Fri, 18 Aug 2023 03:51:59 +0100 Subject: [PATCH 091/144] update documentation --- docs/src/advi.md | 177 +++++++++++++++++++++++++++++++++++++++++--- docs/src/started.md | 8 +- 2 files changed, 170 insertions(+), 15 deletions(-) diff --git a/docs/src/advi.md b/docs/src/advi.md index 37b3541b..3719c89e 100644 --- a/docs/src/advi.md +++ b/docs/src/advi.md @@ -66,34 +66,187 @@ ADVI The STL control variate was proposed by Roeder *et al.* (2017). 
By slightly modifying the differentiation path, it implicitly forms a control variate of the form of ```math -\mathrm{CV}_{\mathrm{STL}}\left(z\right) \triangleq \mathbb{H}\left(q_{\lambda}\right) + \log q_{\lambda}\left(z\right), +\begin{aligned} + \mathrm{CV}_{\mathrm{STL}}\left(z\right) + &\triangleq + \nabla_{\lambda} \mathbb{H}\left(q_{\lambda}\right) + \nabla_{\lambda} \log q_{\nu}\left(z_{\lambda}\left(u\right)\right) \\ + &= + -\nabla_{\lambda} \mathbb{E}_{u \sim \varphi} \log q_{\nu}\left(z_{\lambda}\left(u\right)\right) + \nabla_{\lambda} \log q_{\nu}\left(z_{\lambda}\left(u\right)\right) +\end{aligned} ``` -which has a mean of zero. +where ``\nu = \lambda`` is set to avoid differentiating through the density of ``q_{\lambda}``. +We can see that this vector-valued function has a mean of zero and is therefore a valid control variate. Adding this to the closed-form entropy ELBO estimator yields the STL estimator: ```math \begin{aligned} - \widehat{\mathrm{ELBO}}_{\mathrm{STL}}\left(\lambda\right) - &\triangleq \mathbb{E}\left[ \log \pi \left(z\right) \right] - \log q_{\lambda} \left(z\right) \\ - &= \mathbb{E}\left[ \log \pi\left(z\right) \right] - + \mathbb{H}\left(q_{\lambda}\right) - \mathrm{CV}_{\mathrm{STL}}\left(z\right) \\ - &= \widehat{\mathrm{ELBO}}\left(\lambda\right) - - \mathrm{CV}_{\mathrm{STL}}\left(z\right), + \widehat{\nabla \mathrm{ELBO}}_{\mathrm{STL}}\left(\lambda\right) + &\triangleq \mathbb{E}_{u \sim \varphi}\left[ + \nabla_{\lambda} \log \pi \left(z_{\lambda}\left(u\right)\right) + - + \nabla_{\lambda} \log q_{\nu} \left(z_{\lambda}\left(u\right)\right) + \right] + \\ + &= + \mathbb{E}\left[ \nabla_{\lambda} \log \pi\left(z_{\lambda}\left(u\right)\right) \right] + + + \nabla_{\lambda} \mathbb{H}\left(q_{\lambda}\right) + - + \mathrm{CV}_{\mathrm{STL}}\left(z\right) + \\ + &= + \widehat{\nabla \mathrm{ELBO}}\left(\lambda\right) + - + \mathrm{CV}_{\mathrm{STL}}\left(z\right), \end{aligned} ``` -which has the same expectation, but lower
variance when ``\pi \approx q_{\lambda}``, and higher variance when ``\pi \not\approx q_{\lambda}``. +which has the same expectation as the original ADVI estimator, but lower variance when ``\pi \approx q_{\lambda}``, and higher variance when ``\pi \not\approx q_{\lambda}``. The conditions for which the STL estimator results in lower variance is still an active subject for research. +The main downside of the STL estimator is that it needs to evaluate and differentiate the log density of ``q_{\lambda}`` in every iteration. +Depending on the variational family, this might be computationally inefficient or even numerically unstable. +For example, if ``q_{\lambda}`` is a Gaussian with a full-rank covariance, a back-substitution must be performed at every step, making the per-iteration complexity ``\mathcal{O}(d^3)`` and reducing numerical stability. + + The STL control variate can be used by changing the entropy estimator using the following object: ```@docs StickingTheLandingEntropy ``` -For example: -```julia -ADVI(prob, n_samples; entropy = StickingTheLandingEntropy(), b = bijector) +```@setup stl +using LogDensityProblems +using SimpleUnPack +using PDMats +using Bijectors +using LinearAlgebra +using Plots + +using Optimisers +using ADTypes, ForwardDiff +import AdvancedVI as AVI + +struct NormalLogNormal{MX,SX,MY,SY} + μ_x::MX + σ_x::SX + μ_y::MY + Σ_y::SY +end + +function LogDensityProblems.logdensity(model::NormalLogNormal, θ) + @unpack μ_x, σ_x, μ_y, Σ_y = model + logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end]) +end + +function LogDensityProblems.dimension(model::NormalLogNormal) + length(model.μ_y) + 1 +end + +function LogDensityProblems.capabilities(::Type{<:NormalLogNormal}) + LogDensityProblems.LogDensityOrder{0}() +end + +n_dims = 10 +μ_x = randn() +σ_x = exp.(randn()) +μ_y = randn(n_dims) +σ_y = exp.(randn(n_dims)) +model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2)); + +d = LogDensityProblems.dimension(model); +μ = 
randn(d); +L = Diagonal(ones(d)); +q0 = AVI.VIMeanFieldGaussian(μ, L) + +model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2)); + +function Bijectors.bijector(model::NormalLogNormal) + @unpack μ_x, σ_x, μ_y, Σ_y = model + Bijectors.Stacked( + Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]), + [1:1, 2:1+length(μ_y)]) +end ``` +Let us come back to the example in [Getting Started](@ref getting_started), where a `LogDensityProblem` is given as `model`. +In this example, the true posterior is contained within the variational family. +This setting is known as "perfect variational family specification." +In this case, the STL estimator is able to converge exponentially fast to the true solution. + +Recall that the original ADVI objective with a closed-form entropy (CFE) is given as follows: +```@example stl +n_montecarlo = 1; +b = Bijectors.bijector(model); +b⁻¹ = inverse(b) + +cfe = AVI.ADVI(model, n_montecarlo; invbij = b⁻¹) +``` +The STL estimator can instead be created as follows: +```@example stl +stl = AVI.ADVI(model, n_montecarlo; entropy = AVI.StickingTheLandingEntropy(), invbij = b⁻¹); +``` + +```@setup stl +n_max_iter = 10^4 + +idx = [1] +callback!(; stat, est_state, restructure, λ) = begin + if mod(idx[1], 100) == 1 + idx[:] .+= 1 + (elbo_accurate = cfe(restructure(λ); n_samples=10^4),) + else + idx[:] .+= 1 + NamedTuple() + end +end + +_, stats_cfe, _ = AVI.optimize( + cfe, + q0, + n_max_iter; + show_progress = false, + callback! = callback!, + adbackend = AutoForwardDiff(), + optimizer = Optimisers.Adam(1e-3) +); + +idx[:] .= 1 +_, stats_stl, _ = AVI.optimize( + stl, + q0, + n_max_iter; + show_progress = false, + callback! = callback!, + adbackend = AutoForwardDiff(), + optimizer = Optimisers.Adam(1e-3) +); + +fmc = AVI.ADVI(model, n_montecarlo; entropy = AVI.MonteCarloEntropy(), invbij = b⁻¹) +idx[:] .= 1 +_, stats_fmc, _ = AVI.optimize( + fmc, + q0, + n_max_iter; + show_progress = false, + callback! 
= callback!, + adbackend = AutoForwardDiff(), + optimizer = Optimisers.Adam(1e-3) +); + +t = [stat.iteration for stat ∈ stats_cfe[1:100:end]] +y_cfe = [stat.elbo_accurate for stat ∈ stats_cfe[1:100:end]] +y_stl = [stat.elbo_accurate for stat ∈ stats_stl[1:100:end]] +y_fmc = [stat.elbo_accurate for stat ∈ stats_fmc[1:100:end]] +plot( t, y_cfe, label="ADVI CFE", xlabel="Iteration", ylabel="ELBO", ylims=[-5, 1]) +plot!(t, y_stl, label="ADVI STL", xlabel="Iteration", ylabel="ELBO", ylims=[-5, 1]) +plot!(t, y_fmc, label="ADVI FMC", xlabel="Iteration", ylabel="ELBO", ylims=[-5, 1]) +savefig("advi_stl_elbo.svg") +nothing +``` +![](advi_stl_elbo.svg) + +We can see that the noise of the STL estimator converges to a more accurate solution compared to the CFE estimator. + + ## References 1. Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. Journal of machine learning research. 2. Titsias, M., & Lázaro-Gredilla, M. (2014, June). Doubly stochastic variational Bayes for non-conjugate inference. In International conference on machine learning (pp. 1971-1979). PMLR. diff --git a/docs/src/started.md b/docs/src/started.md index fec60f1a..b89a140a 100644 --- a/docs/src/started.md +++ b/docs/src/started.md @@ -90,7 +90,7 @@ We now need to select 1. a variational objective, and 2. a variational family. Here, we will use the [`ADVI` objective](@ref advi), which expects an object implementing the [`LogDensityProblems`](https://github.com/tpapp/LogDensityProblems.jl) interface, and the inverse bijector. ```@example advi n_montecaro = 10; -objective = AVI.ADVI(model, n_montecaro; b = b⁻¹) +objective = AVI.ADVI(model, n_montecaro; invbij = b⁻¹) ``` For the variational family, we will use the classic mean-field Gaussian family. 
```@example advi @@ -120,10 +120,12 @@ using Plots t = [stat.iteration for stat ∈ stats] y = [stat.elbo for stat ∈ stats] -plot(t[1:100:end], y[1:100:end]) -savefig("advi_example_elbo.svg"); nothing +plot(t, y, label="ADVI", xlabel="Iteration", ylabel="ELBO") +savefig("advi_example_elbo.svg") +nothing ``` ![](advi_example_elbo.svg) + Further information can be gathered by defining your own `callback!`. The final ELBO can be estimated by calling the objective directly with a different number of Monte Carlo samples as follows: From 8682fd92d7746e3f6741bbcb2f2029b12653ba72 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Fri, 18 Aug 2023 04:00:17 +0100 Subject: [PATCH 092/144] update STL documentation --- docs/src/advi.md | 42 ++++++++---------------------------------- 1 file changed, 8 insertions(+), 34 deletions(-) diff --git a/docs/src/advi.md b/docs/src/advi.md index 3719c89e..0d5b9568 100644 --- a/docs/src/advi.md +++ b/docs/src/advi.md @@ -63,6 +63,7 @@ ADVI ``` ## The `StickingTheLanding` Control Variate + The STL control variate was proposed by Roeder *et al.* (2017). By slightly modifying the differentiation path, it implicitly forms a control variate of the form of ```math @@ -188,63 +189,36 @@ stl = AVI.ADVI(model, n_montecarlo; entropy = AVI.StickingTheLandingEntropy(), i ```@setup stl n_max_iter = 10^4 -idx = [1] -callback!(; stat, est_state, restructure, λ) = begin - if mod(idx[1], 100) == 1 - idx[:] .+= 1 - (elbo_accurate = cfe(restructure(λ); n_samples=10^4),) - else - idx[:] .+= 1 - NamedTuple() - end -end - _, stats_cfe, _ = AVI.optimize( cfe, q0, n_max_iter; show_progress = false, - callback! = callback!, adbackend = AutoForwardDiff(), optimizer = Optimisers.Adam(1e-3) ); -idx[:] .= 1 _, stats_stl, _ = AVI.optimize( stl, q0, n_max_iter; show_progress = false, - callback! 
= callback!, - adbackend = AutoForwardDiff(), - optimizer = Optimisers.Adam(1e-3) -); - -fmc = AVI.ADVI(model, n_montecarlo; entropy = AVI.MonteCarloEntropy(), invbij = b⁻¹) -idx[:] .= 1 -_, stats_fmc, _ = AVI.optimize( - fmc, - q0, - n_max_iter; - show_progress = false, - callback! = callback!, adbackend = AutoForwardDiff(), optimizer = Optimisers.Adam(1e-3) ); -t = [stat.iteration for stat ∈ stats_cfe[1:100:end]] -y_cfe = [stat.elbo_accurate for stat ∈ stats_cfe[1:100:end]] -y_stl = [stat.elbo_accurate for stat ∈ stats_stl[1:100:end]] -y_fmc = [stat.elbo_accurate for stat ∈ stats_fmc[1:100:end]] -plot( t, y_cfe, label="ADVI CFE", xlabel="Iteration", ylabel="ELBO", ylims=[-5, 1]) -plot!(t, y_stl, label="ADVI STL", xlabel="Iteration", ylabel="ELBO", ylims=[-5, 1]) -plot!(t, y_fmc, label="ADVI FMC", xlabel="Iteration", ylabel="ELBO", ylims=[-5, 1]) +t = [stat.iteration for stat ∈ stats_cfe] +y_cfe = [stat.elbo for stat ∈ stats_cfe] +y_stl = [stat.elbo for stat ∈ stats_stl] +plot( t, y_cfe, label="ADVI CFE", xlabel="Iteration", ylabel="ELBO") +plot!(t, y_stl, label="ADVI STL", xlabel="Iteration", ylabel="ELBO") savefig("advi_stl_elbo.svg") nothing ``` ![](advi_stl_elbo.svg) -We can see that the noise of the STL estimator converges to a more accurate solution compared to the CFE estimator. +We can see that the noise of the STL estimator becomes smaller as VI converges. +However, the difference in speed of convergence may not always be significant. ## References From 9a16ee109a8b095e36d700389b18a35dc1355c2c Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Fri, 18 Aug 2023 04:01:48 +0100 Subject: [PATCH 093/144] update STL documentation --- docs/src/advi.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/advi.md b/docs/src/advi.md index 0d5b9568..afb780cb 100644 --- a/docs/src/advi.md +++ b/docs/src/advi.md @@ -218,7 +218,7 @@ nothing ![](advi_stl_elbo.svg) We can see that the noise of the STL estimator becomes smaller as VI converges. 
-However, the difference in speed of convergence may not always be significant. +However, the speed of convergence may not always be significantly different. ## References From fc74afaef98e8c31ed04c55abdce20d25a644e4d Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Fri, 18 Aug 2023 04:03:33 +0100 Subject: [PATCH 094/144] update location scale documentation --- docs/src/locscale.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/src/locscale.md b/docs/src/locscale.md index a4bc2dc1..63ff5cb4 100644 --- a/docs/src/locscale.md +++ b/docs/src/locscale.md @@ -10,6 +10,7 @@ z \stackrel{d}{=} C u + m;\quad u \sim \varphi where ``C`` is the *scale*, ``m`` is the location, and ``\varphi`` is the *base distribution*. ``m`` and ``C`` form the variational parameters ``\lambda = (m, C)`` of ``q_{\lambda}``. The location-scale family encompases many practical variational families, which can be instantiated by setting the *base distribution* of ``u`` and the structure of ``C``. + The probability density is given by ```math q_{\lambda}(z) = {|C|}^{-1} \varphi(C^{-1}(z - m)) @@ -19,6 +20,8 @@ and the entropy is given as \mathcal{H}(q_{\lambda}) = \mathcal{H}(\varphi) + \log |C|, ``` where ``\mathcal{H}(\varphi)`` is the entropy of the base distribution. +Notice that ``\mathcal{H}(\varphi)`` does not depend on the variational parameters ``\lambda = (m, C)``. +The derivative of the entropy with respect to ``\lambda`` is thus independent of the base distribution. ## Constructors From 4be30a1a44c70b4e9356768fd2d8ac662e7bc461 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sun, 20 Aug 2023 00:10:48 +0100 Subject: [PATCH 095/144] fix README --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index e8718e7c..c43748e5 100644 --- a/README.md +++ b/README.md @@ -11,14 +11,14 @@ For example, `Turing` combines `Turing.Model`s with `AdvancedVI.ADVI` and [`Bije `AdvancedVI` basically expects a `LogDensityProblem`.
For example, for the normal-log-normal model: $$ -\begin{aligned} +\begin{align*} x &\sim \mathsf{log\text{-}normal}\left(\mu_x, \sigma_x^2\right) \\ y &\sim \mathsf{normal}\left(\mu_y, \sigma_y^2\right) -\end{aligned} -$$ +\end{align*} +$$ A `LogDensityProblem` can be implemented as -``` +```julia using LogDensityProblems struct NormalLogNormal{MX,SX,MY,SY} From c58309dbaea25c986074b69a02f5bc6035dfcde8 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sun, 20 Aug 2023 00:12:15 +0100 Subject: [PATCH 096/144] fix math in README --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index c43748e5..8def2d98 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,7 @@ For example, `Turing` combines `Turing.Model`s with `AdvancedVI.ADVI` and [`Bije `AdvancedVI` basically expects a `LogDensityProblem`. For example, for the normal-log-normal model: + $$ \begin{align*} x &\sim \mathsf{log\text{-}normal}\left(\mu_x, \sigma_x^2\right) \\ From 5b5bd3e9c3f4e90ac0d34f789b17c43c199ebd7d Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sun, 20 Aug 2023 03:08:20 +0100 Subject: [PATCH 097/144] add gradient to arguments of callback!, remove `gradient_norm` info --- src/objectives/elbo/advi.jl | 2 +- src/optimize.jl | 8 ++++---- test/optimize.jl | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl index 788449d1..d8719fa7 100644 --- a/src/objectives/elbo/advi.jl +++ b/src/objectives/elbo/advi.jl @@ -57,7 +57,7 @@ function (advi::ADVI)( ) 𝔼ℓ = mean(eachcol(ηs)) do ηᵢ zᵢ, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(advi.invbij, ηᵢ) - (advi.ℓπ(zᵢ) + logdetjacᵢ) + advi.ℓπ(zᵢ) + logdetjacᵢ end ℍ = advi.entropy(q_η, ηs) 𝔼ℓ + ℍ diff --git a/src/optimize.jl b/src/optimize.jl index 93e6f754..43b06689 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -26,7 +26,7 @@ Optimize the variational objective `objective` by estimating (stochastic) gradie # Arguments - `objective`: Variational Objective. 
- `λ₀`: Initial value of the variational parameters. -- `restructure`: Function that reconstructs the variational approximation from the flattened parameters. +- `restruct`: Function that reconstructs the variational approximation from the flattened parameters. - `q`: Initial variational approximation. The variational parameters must be extractable through `Optimisers.destructure`. - `n_max_iter`: Maximum number of iterations. @@ -35,7 +35,7 @@ Optimize the variational objective `objective` by estimating (stochastic) gradie - `optimizer`: Optimizer used for inference. (Type: `<: Optimisers.AbstractRule`; Default: `Adam`.) - `rng`: Random number generator. (Type: `<: AbstractRNG`; Default: `Random.default_rng()`.) - `show_progress`: Whether to show the progress bar. (Type: `<: Bool`; Default: `true`.) -- `callback!`: Callback function called after every iteration. The signature is `cb(; t, est_state, stats, restructure, λ)`, which returns a dictionary-like object containing statistics to be displayed on the progress bar. The variational approximation can be reconstructed as `restructure(λ)`. If `objective` is stateful, `est_state` contains its state. (Default: `nothing`.) +- `callback!`: Callback function called after every iteration. The signature is `cb(; stat, est_state, restructure, λ, g)`, which returns a dictionary-like object containing statistics to be displayed on the progress bar. The variational approximation can be reconstructed as `restructure(λ)`, and `g` is the stochastic gradient. If the estimator associated with `objective` is stateful, `est_state` contains its state. (Default: `nothing`.) - `prog`: Progress bar configuration. (Default: `ProgressMeter.Progress(n_max_iter; desc="Optimizing", barlen=31, showspeed=true, enabled=prog)`.)
# Returns @@ -76,11 +76,11 @@ function optimize( g = DiffResults.gradient(grad_buf) opt_state, λ = Optimisers.update!(opt_state, λ, g) - stat′ = (iteration=t, gradient_norm=norm(g)) + stat′ = (iteration = t,) stat = merge(stat, stat′) if !isnothing(callback!) - stat′ = callback!(; est_state, stat, restructure, λ) + stat′ = callback!(; est_state, stat, λ, g, restructure) stat = !isnothing(stat′) ? merge(stat′, stat) : stat end diff --git a/test/optimize.jl b/test/optimize.jl index d514d236..920a3070 100644 --- a/test/optimize.jl +++ b/test/optimize.jl @@ -67,7 +67,7 @@ include("models/utils.jl") rng = Philox4x(UInt64, seed, 8) test_values = rand(rng, T) - callback!(; stat, est_state, restructure, λ) = begin + callback!(; stat, est_state, restructure, λ, g) = begin (test_value = test_values[stat.iteration],) end From 967021d2a1aa827d9dedda00c2b3eae39638986e Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 21 Aug 2023 23:43:43 +0100 Subject: [PATCH 098/144] fix math in README.md Co-authored-by: David Widmann --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 8def2d98..83c2e8bc 100644 --- a/README.md +++ b/README.md @@ -12,10 +12,10 @@ For example, `Turing` combines `Turing.Model`s with `AdvancedVI.ADVI` and [`Bije For example, for the normal-log-normal model: $$ -\begin{align*} -x &\sim \mathsf{log\text{-}normal}\left(\mu_x, \sigma_x^2\right) \\ -y &\sim \mathsf{normal}\left(\mu_y, \sigma_y^2\right) -\end{align*} +\begin{aligned} +x &\sim \operatorname{LogNormal}\left(\mu_x, \sigma_x^2\right) \\ +y &\sim \mathcal{N}\left(\mu_y, \sigma_y^2\right) +\end{aligned} $$ A `LogDensityProblem` can be implemented as From 4dab522ff2583f7a622f7c6d35f829f8daf37cf2 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 21 Aug 2023 23:44:16 +0100 Subject: [PATCH 099/144] fix type constraint in `ZygoteExt` Co-authored-by: David Widmann --- ext/AdvancedVIZygoteExt.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 
deletions(-) diff --git a/ext/AdvancedVIZygoteExt.jl b/ext/AdvancedVIZygoteExt.jl index b447d071..c3d891bb 100644 --- a/ext/AdvancedVIZygoteExt.jl +++ b/ext/AdvancedVIZygoteExt.jl @@ -12,10 +12,10 @@ else end function AdvancedVI.value_and_gradient!( - ad::ADTypes.AutoZygote, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult -) where {T<:Real} + ad::ADTypes.AutoZygote, f, θ::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult +) y, back = Zygote.pullback(f, θ) - ∇θ = back(one(T)) + ∇θ = back(one(y)) DiffResults.value!(out, y) DiffResults.gradient!(out, first(∇θ)) return out From 8ab2f19d208d82d720f462107b4949a16bfa3513 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 21 Aug 2023 23:44:58 +0100 Subject: [PATCH 100/144] fix import of `Random` Co-authored-by: David Widmann --- src/AdvancedVI.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index cca220f1..a314e992 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -4,7 +4,7 @@ module AdvancedVI using SimpleUnPack: @unpack, @pack! 
using Accessors -import Random: AbstractRNG, default_rng +using Random: AbstractRNG, default_rng using Distributions import Distributions: logpdf, _logpdf, rand, _rand!, _rand!, From 83dec9fdc25226ed2dff13cc576981f09351a229 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 21 Aug 2023 23:46:08 +0100 Subject: [PATCH 101/144] refactor `__init__()` Co-authored-by: David Widmann --- src/AdvancedVI.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index a314e992..348a6a30 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -101,8 +101,8 @@ if !isdefined(Base, :get_extension) # check whether :get_extension is defined in using Requires end -function __init__() - @static if !isdefined(Base, :get_extension) +@static if !isdefined(Base, :get_extension) + function __init__() @require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" begin include("../ext/AdvancedVIEnzymeExt.jl") end From a3e563cd43d937602e01f36e87247068f2a0b4ab Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 21 Aug 2023 23:47:08 +0100 Subject: [PATCH 102/144] fix type constraint in definition of `value_and_gradient!` Co-authored-by: David Widmann --- src/AdvancedVI.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 348a6a30..42cd0dc5 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -39,9 +39,9 @@ const DEBUG = Bool(parse(Int, get(ENV, "DEBUG_ADVANCEDVI", "0"))) value_and_gradient!( ad::ADTypes.AbstractADType, f, - θ::AbstractVector{T}, + θ::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult - ) where {T<:Real} + ) Compute the value and gradient of a function `f` at `θ` using the automatic differentiation backend `ad`. The result is stored in `out`. 
From 5553bb950840ea9b8c6aba7794f52d58d3fce910 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 21 Aug 2023 23:52:56 +0100 Subject: [PATCH 103/144] refactor `ZygoteExt`; use `only` instead of `first` Co-authored-by: David Widmann --- ext/AdvancedVIZygoteExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/AdvancedVIZygoteExt.jl b/ext/AdvancedVIZygoteExt.jl index c3d891bb..7b8f8817 100644 --- a/ext/AdvancedVIZygoteExt.jl +++ b/ext/AdvancedVIZygoteExt.jl @@ -17,7 +17,7 @@ function AdvancedVI.value_and_gradient!( y, back = Zygote.pullback(f, θ) ∇θ = back(one(y)) DiffResults.value!(out, y) - DiffResults.gradient!(out, first(∇θ)) + DiffResults.gradient!(out, only(∇θ)) return out end From 79b455746860f7957a7703d8d99fcbd79e613409 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 21 Aug 2023 23:53:38 +0100 Subject: [PATCH 104/144] refactor type constraint in `ReverseDiffExt` Co-authored-by: David Widmann --- ext/AdvancedVIReverseDiffExt.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/AdvancedVIReverseDiffExt.jl b/ext/AdvancedVIReverseDiffExt.jl index fd7fbaab..520cd9ff 100644 --- a/ext/AdvancedVIReverseDiffExt.jl +++ b/ext/AdvancedVIReverseDiffExt.jl @@ -13,8 +13,8 @@ end # ReverseDiff without compiled tape function AdvancedVI.value_and_gradient!( - ad::ADTypes.AutoReverseDiff, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult -) where {T<:Real} + ad::ADTypes.AutoReverseDiff, f, θ::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult +) tp = ReverseDiff.GradientTape(f, θ) ReverseDiff.gradient!(out, tp, θ) return out From 656b44b03f86ea83cba1d8de3953db956ffbe0ab Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Mon, 21 Aug 2023 23:56:28 +0100 Subject: [PATCH 105/144] refactor remove outdated debug mode macro --- src/AdvancedVI.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 42cd0dc5..ae0dc684 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -32,8 
+32,6 @@ using Bijectors using StatsBase using StatsBase: entropy -const DEBUG = Bool(parse(Int, get(ENV, "DEBUG_ADVANCEDVI", "0"))) - # derivatives """ value_and_gradient!( From c7940636a8e08f5a97740f9872c4cffb4e6bed4d Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 22 Aug 2023 00:10:00 +0100 Subject: [PATCH 106/144] fix remove outdated DEBUG mechanism --- src/optimize.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/optimize.jl b/src/optimize.jl index 43b06689..57ee8030 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -84,7 +84,7 @@ function optimize( stat = !isnothing(stat′) ? merge(stat′, stat) : stat end - AdvancedVI.DEBUG && @debug "Step $t" stat... + @debug "Iteration $t" stat... pm_next!(prog, stat) push!(stats, stat) From 0c5cc1ce8eacc3451bf360bb3c1b0301415242d4 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 22 Aug 2023 00:13:43 +0100 Subject: [PATCH 107/144] fix LaTeX in README: `operatorname` is currently broken --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 83c2e8bc..b3538ccf 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ For example, for the normal-log-normal model: $$ \begin{aligned} -x &\sim \operatorname{LogNormal}\left(\mu_x, \sigma_x^2\right) \\ +x &\sim \mathrm{Log\text{-}Normal}\left(\mu_x, \sigma_x^2\right) \\ y &\sim \mathcal{N}\left(\mu_y, \sigma_y^2\right) \end{aligned} $$ From 29d7d27ca227413275174e12f9258b13b8276fd0 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 22 Aug 2023 01:04:43 +0100 Subject: [PATCH 108/144] remove `SimpleUnPack` dependency --- Project.toml | 1 - docs/Project.toml | 1 - docs/src/advi.md | 9 +++------ docs/src/started.md | 11 ++++------- src/AdvancedVI.jl | 1 - src/distributions/location_scale.jl | 14 +++++++------- test/Project.toml | 1 - test/advi_locscale.jl | 3 +-- test/models/normal.jl | 2 +- test/models/normallognormal.jl | 7 ++++--- test/optimize.jl | 1 - 11 files changed, 20 insertions(+), 31 deletions(-) 
diff --git a/Project.toml b/Project.toml index e099308a..29cc559f 100644 --- a/Project.toml +++ b/Project.toml @@ -19,7 +19,6 @@ PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" -SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" diff --git a/docs/Project.toml b/docs/Project.toml index 182edd3e..1f4ba59f 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -9,7 +9,6 @@ LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" -SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" [compat] ADTypes = "0.1.6" diff --git a/docs/src/advi.md b/docs/src/advi.md index afb780cb..88c11fee 100644 --- a/docs/src/advi.md +++ b/docs/src/advi.md @@ -116,7 +116,6 @@ StickingTheLandingEntropy ```@setup stl using LogDensityProblems -using SimpleUnPack using PDMats using Bijectors using LinearAlgebra @@ -134,7 +133,7 @@ struct NormalLogNormal{MX,SX,MY,SY} end function LogDensityProblems.logdensity(model::NormalLogNormal, θ) - @unpack μ_x, σ_x, μ_y, Σ_y = model + (; μ_x, σ_x, μ_y, Σ_y) = model logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end]) end @@ -151,17 +150,15 @@ n_dims = 10 σ_x = exp.(randn()) μ_y = randn(n_dims) σ_y = exp.(randn(n_dims)) -model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2)); +model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2)); d = LogDensityProblems.dimension(model); μ = randn(d); L = Diagonal(ones(d)); q0 = AVI.VIMeanFieldGaussian(μ, L) -model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2)); - function Bijectors.bijector(model::NormalLogNormal) - @unpack μ_x, σ_x, μ_y, Σ_y = model + (; μ_x, σ_x, μ_y, Σ_y) = 
model Bijectors.Stacked( Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]), [1:1, 2:1+length(μ_y)]) diff --git a/docs/src/started.md b/docs/src/started.md index b89a140a..4a1d26ec 100644 --- a/docs/src/started.md +++ b/docs/src/started.md @@ -27,7 +27,6 @@ ADVI with `Bijectors.Exp` bijectors is able to infer this model exactly. Using the `LogDensityProblems` interface, we the model can be defined as follows: ```@example advi using LogDensityProblems -using SimpleUnPack struct NormalLogNormal{MX,SX,MY,SY} μ_x::MX @@ -37,7 +36,7 @@ struct NormalLogNormal{MX,SX,MY,SY} end function LogDensityProblems.logdensity(model::NormalLogNormal, θ) - @unpack μ_x, σ_x, μ_y, Σ_y = model + (; μ_x, σ_x, μ_y, Σ_y) = model logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end]) end @@ -51,14 +50,14 @@ end ``` Let's now instantiate the model ```@example advi -using PDMats +using LinearAlgebra n_dims = 10 μ_x = randn() σ_x = exp.(randn()) μ_y = randn(n_dims) σ_y = exp.(randn(n_dims)) -model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2)); +model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2)); ``` Since the `y` follows a log-normal prior, its support is bounded to be the positive half-space ``\mathbb{R}_+``. @@ -67,7 +66,7 @@ Thus, we will use [Bijectors](https://github.com/TuringLang/Bijectors.jl) to mat using Bijectors function Bijectors.bijector(model::NormalLogNormal) - @unpack μ_x, σ_x, μ_y, Σ_y = model + (; μ_x, σ_x, μ_y, Σ_y) = model Bijectors.Stacked( Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]), [1:1, 2:1+length(μ_y)]) @@ -94,8 +93,6 @@ objective = AVI.ADVI(model, n_montecaro; invbij = b⁻¹) ``` For the variational family, we will use the classic mean-field Gaussian family. 
```@example advi -using LinearAlgebra - d = LogDensityProblems.dimension(model); μ = randn(d); L = Diagonal(ones(d)); diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index ae0dc684..5d0c3f8d 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -1,7 +1,6 @@ module AdvancedVI -using SimpleUnPack: @unpack, @pack! using Accessors using Random: AbstractRNG, default_rng diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl index 3113c679..73be42b9 100644 --- a/src/distributions/location_scale.jl +++ b/src/distributions/location_scale.jl @@ -35,42 +35,42 @@ Base.length(q::VILocationScale) = length(q.location) Base.size(q::VILocationScale) = size(q.location) function StatsBase.entropy(q::VILocationScale) - @unpack location, scale, dist = q + (; location, scale, dist) = q n_dims = length(location) n_dims*entropy(dist) + first(logabsdet(scale)) end function logpdf(q::VILocationScale, z::AbstractVector{<:Real}) - @unpack location, scale, dist = q + (; location, scale, dist) = q sum(zᵢ -> logpdf(dist, zᵢ), scale \ (z - location)) - first(logabsdet(scale)) end function _logpdf(q::VILocationScale, z::AbstractVector{<:Real}) - @unpack location, scale, dist = q + (; location, scale, dist) = q sum(zᵢ -> logpdf(dist, zᵢ), scale \ (z - location)) - first(logabsdet(scale)) end function rand(q::VILocationScale) - @unpack location, scale, dist = q + (; location, scale, dist) = q n_dims = length(location) scale*rand(dist, n_dims) + location end function rand(rng::AbstractRNG, q::VILocationScale, num_samples::Int) - @unpack location, scale, dist = q + (; location, scale, dist) = q n_dims = length(location) scale*rand(rng, dist, n_dims, num_samples) .+ location end function _rand!(rng::AbstractRNG, q::VILocationScale, x::AbstractVector{<:Real}) - @unpack location, scale, dist = q + (; location, scale, dist) = q rand!(rng, dist, x) x .= scale*x return x += location end function _rand!(rng::AbstractRNG, q::VILocationScale, x::AbstractMatrix{<:Real}) - 
@unpack location, scale, dist = q + (; location, scale, dist) = q rand!(rng, dist, x) x *= scale return x += location diff --git a/test/Project.toml b/test/Project.toml index 2f38c88f..277b73c8 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -14,7 +14,6 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Random123 = "74087812-796a-5b5d-8853-05524746bad3" ReTest = "e0db7c4e-2690-44b9-bad6-7687da720f89" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl index 71cf22d5..c6aee68b 100644 --- a/test/advi_locscale.jl +++ b/test/advi_locscale.jl @@ -8,7 +8,6 @@ using Optimisers using Distributions using PDMats using LinearAlgebra -using SimpleUnPack: @unpack struct TestModel{M,L,S} model::M @@ -48,7 +47,7 @@ include("models/utils.jl") T = 10000 modelstats = modelconstr(realtype; rng) - @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats + (; model, μ_true, L_true, n_dims, is_meanfield) = modelstats b = Bijectors.bijector(model) b⁻¹ = inverse(b) diff --git a/test/models/normal.jl b/test/models/normal.jl index f60ad5f3..1dfa653c 100644 --- a/test/models/normal.jl +++ b/test/models/normal.jl @@ -5,7 +5,7 @@ struct TestMvNormal{M,S} end function LogDensityProblems.logdensity(model::TestMvNormal, θ) - @unpack μ, Σ = model + (; μ, Σ) = model logpdf(MvNormal(μ, Σ), θ) end diff --git a/test/models/normallognormal.jl b/test/models/normallognormal.jl index cab73cce..49da5bf6 100644 --- a/test/models/normallognormal.jl +++ b/test/models/normallognormal.jl @@ -7,7 +7,7 @@ struct NormalLogNormal{MX,SX,MY,SY} end function LogDensityProblems.logdensity(model::NormalLogNormal, θ) - @unpack μ_x, σ_x, μ_y, Σ_y = model + (; μ_x, σ_x, μ_y, Σ_y) = model logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, 
Σ_y), θ[2:end]) end @@ -20,7 +20,7 @@ function LogDensityProblems.capabilities(::Type{<:NormalLogNormal}) end function Bijectors.bijector(model::NormalLogNormal) - @unpack μ_x, σ_x, μ_y, Σ_y = model + (; μ_x, σ_x, μ_y, Σ_y) = model Bijectors.Stacked( Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]), [1:1, 2:1+length(μ_y)]) @@ -56,7 +56,8 @@ function normallognormal_meanfield(realtype; rng = default_rng()) μ_y = randn(rng, realtype, n_dims) σ_y = log.(exp.(randn(rng, realtype, n_dims)) .+ 1) - model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2)) + #model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2)) + model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2)) μ = vcat(μ_x, μ_y) L = vcat(σ_x, σ_y) |> Diagonal diff --git a/test/optimize.jl b/test/optimize.jl index 920a3070..c96fa6cd 100644 --- a/test/optimize.jl +++ b/test/optimize.jl @@ -6,7 +6,6 @@ using Optimisers using Distributions using PDMats using LinearAlgebra -using SimpleUnPack: @unpack struct TestModel{M,L,S} model::M From 75eef445a5daea37d79106851b26af292de2542b Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 22 Aug 2023 01:05:08 +0100 Subject: [PATCH 109/144] fix LaTeX in docs and README --- README.md | 2 +- docs/src/started.md | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index b3538ccf..d9638bfd 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ For example, for the normal-log-normal model: $$ \begin{aligned} -x &\sim \mathrm{Log\text{-}Normal}\left(\mu_x, \sigma_x^2\right) \\ +x &\sim \mathrm{LogNormal}\left(\mu_x, \sigma_x^2\right) \\ y &\sim \mathcal{N}\left(\mu_y, \sigma_y^2\right) \end{aligned} $$ diff --git a/docs/src/started.md b/docs/src/started.md index 4a1d26ec..a129fc46 100644 --- a/docs/src/started.md +++ b/docs/src/started.md @@ -18,8 +18,8 @@ optimize In this tutorial, we will work with a basic `normal-log-normal` model. 
```math \begin{aligned} -x &\sim \mathsf{log\text{-}normal}\left(\mu_x, \sigma_x^2\right) \\ -y &\sim \mathsf{normal}\left(\mu_y, \sigma_y^2\right) +x &\sim \mathrm{LogNormal}\left(\mu_x, \sigma_x^2\right) \\ +y &\sim \mathcal{N}\left(\mu_y, \sigma_y^2\right) \end{aligned} ``` ADVI with `Bijectors.Exp` bijectors is able to infer this model exactly. From 40574f46864513ced4051867159e0660b2f4b061 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 22 Aug 2023 01:10:29 +0100 Subject: [PATCH 110/144] add warning about forward-mode AD when using `LocationScale` --- docs/src/locscale.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/src/locscale.md b/docs/src/locscale.md index 63ff5cb4..8f14a9ad 100644 --- a/docs/src/locscale.md +++ b/docs/src/locscale.md @@ -23,6 +23,9 @@ where ``\mathcal{H}(\varphi)`` is the entropy of the base distribution. Notice the ``\mathcal{H}(\varphi)`` does not depend on ``\log |C|``. The derivative of the entropy with respect to ``\lambda`` is thus independent of the base distribution. +!!! warning + `LocationScale` and its specializations such as `VIFullRankGaussian` and `VIMeanFieldGaussian` are inefficient with forward-mode differentiation packages like `ForwardDiff`. Especially, they scale poorly with the number of dimensions. Please use reverse-mode differentation packages such as `ReverseDiff` and `Zygote`. + ## Constructors ```@docs From 8738256bd44fc38dd49807a69f70da41fa50448c Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 22 Aug 2023 01:14:04 +0100 Subject: [PATCH 111/144] fix documentation --- README.md | 7 +++---- docs/src/started.md | 2 +- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index d9638bfd..07407fa9 100644 --- a/README.md +++ b/README.md @@ -8,17 +8,16 @@ For example, `Turing` combines `Turing.Model`s with `AdvancedVI.ADVI` and [`Bije ## Examples -`AdvancedVI` basically expects a `LogDensityProblem`. +`AdvancedVI` expects a `LogDensityProblem`. 
For example, for the normal-log-normal model: $$ \begin{aligned} x &\sim \mathrm{LogNormal}\left(\mu_x, \sigma_x^2\right) \\ -y &\sim \mathcal{N}\left(\mu_y, \sigma_y^2\right) +y &\sim \mathcal{N}\left(\mu_y, \sigma_y^2\right), \end{aligned} $$ - -A `LogDensityProblem` can be implemented as +a `LogDensityProblem` can be implemented as ```julia using LogDensityProblems diff --git a/docs/src/started.md b/docs/src/started.md index a129fc46..b07a5bd3 100644 --- a/docs/src/started.md +++ b/docs/src/started.md @@ -15,7 +15,7 @@ optimize ``` ## `ADVI` Example -In this tutorial, we will work with a basic `normal-log-normal` model. +In this tutorial, we will work with a `normal-log-normal` model. ```math \begin{aligned} x &\sim \mathrm{LogNormal}\left(\mu_x, \sigma_x^2\right) \\ From 817374403e58cb11e4e0e3aaee045c350d5bdfdc Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 22 Aug 2023 01:18:52 +0100 Subject: [PATCH 112/144] fix remove reamining use of `@unpack` --- test/optimize.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/optimize.jl b/test/optimize.jl index c96fa6cd..96930495 100644 --- a/test/optimize.jl +++ b/test/optimize.jl @@ -25,7 +25,7 @@ include("models/utils.jl") T = 1000 modelstats = normallognormal_meanfield(Float64; rng) - @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats + (; model, μ_true, L_true, n_dims, is_meanfield) = modelstats # Global Test Configurations b⁻¹ = Bijectors.bijector(model) |> inverse From e0548aecdc3468aa836d58b55aa3be60124d4782 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 21 Aug 2023 22:22:02 -0400 Subject: [PATCH 113/144] Revert "remove `SimpleUnPack` dependency" This reverts commit 29d7d27ca227413275174e12f9258b13b8276fd0. 
--- Project.toml | 1 + docs/Project.toml | 1 + docs/src/advi.md | 9 ++++++--- docs/src/started.md | 11 +++++++---- src/AdvancedVI.jl | 1 + src/distributions/location_scale.jl | 14 +++++++------- test/Project.toml | 1 + test/advi_locscale.jl | 3 ++- test/models/normal.jl | 2 +- test/models/normallognormal.jl | 7 +++---- test/optimize.jl | 1 + 11 files changed, 31 insertions(+), 20 deletions(-) diff --git a/Project.toml b/Project.toml index 29cc559f..e099308a 100644 --- a/Project.toml +++ b/Project.toml @@ -19,6 +19,7 @@ PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" +SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" diff --git a/docs/Project.toml b/docs/Project.toml index 1f4ba59f..182edd3e 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -9,6 +9,7 @@ LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" [compat] ADTypes = "0.1.6" diff --git a/docs/src/advi.md b/docs/src/advi.md index 88c11fee..afb780cb 100644 --- a/docs/src/advi.md +++ b/docs/src/advi.md @@ -116,6 +116,7 @@ StickingTheLandingEntropy ```@setup stl using LogDensityProblems +using SimpleUnPack using PDMats using Bijectors using LinearAlgebra @@ -133,7 +134,7 @@ struct NormalLogNormal{MX,SX,MY,SY} end function LogDensityProblems.logdensity(model::NormalLogNormal, θ) - (; μ_x, σ_x, μ_y, Σ_y) = model + @unpack μ_x, σ_x, μ_y, Σ_y = model logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end]) end @@ -150,15 +151,17 @@ n_dims = 10 σ_x = exp.(randn()) μ_y = randn(n_dims) σ_y = exp.(randn(n_dims)) -model = 
NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2)); +model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2)); d = LogDensityProblems.dimension(model); μ = randn(d); L = Diagonal(ones(d)); q0 = AVI.VIMeanFieldGaussian(μ, L) +model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2)); + function Bijectors.bijector(model::NormalLogNormal) - (; μ_x, σ_x, μ_y, Σ_y) = model + @unpack μ_x, σ_x, μ_y, Σ_y = model Bijectors.Stacked( Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]), [1:1, 2:1+length(μ_y)]) diff --git a/docs/src/started.md b/docs/src/started.md index b07a5bd3..4e2b4380 100644 --- a/docs/src/started.md +++ b/docs/src/started.md @@ -27,6 +27,7 @@ ADVI with `Bijectors.Exp` bijectors is able to infer this model exactly. Using the `LogDensityProblems` interface, we the model can be defined as follows: ```@example advi using LogDensityProblems +using SimpleUnPack struct NormalLogNormal{MX,SX,MY,SY} μ_x::MX @@ -36,7 +37,7 @@ struct NormalLogNormal{MX,SX,MY,SY} end function LogDensityProblems.logdensity(model::NormalLogNormal, θ) - (; μ_x, σ_x, μ_y, Σ_y) = model + @unpack μ_x, σ_x, μ_y, Σ_y = model logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end]) end @@ -50,14 +51,14 @@ end ``` Let's now instantiate the model ```@example advi -using LinearAlgebra +using PDMats n_dims = 10 μ_x = randn() σ_x = exp.(randn()) μ_y = randn(n_dims) σ_y = exp.(randn(n_dims)) -model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2)); +model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2)); ``` Since the `y` follows a log-normal prior, its support is bounded to be the positive half-space ``\mathbb{R}_+``. 
@@ -66,7 +67,7 @@ Thus, we will use [Bijectors](https://github.com/TuringLang/Bijectors.jl) to mat using Bijectors function Bijectors.bijector(model::NormalLogNormal) - (; μ_x, σ_x, μ_y, Σ_y) = model + @unpack μ_x, σ_x, μ_y, Σ_y = model Bijectors.Stacked( Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]), [1:1, 2:1+length(μ_y)]) @@ -93,6 +94,8 @@ objective = AVI.ADVI(model, n_montecaro; invbij = b⁻¹) ``` For the variational family, we will use the classic mean-field Gaussian family. ```@example advi +using LinearAlgebra + d = LogDensityProblems.dimension(model); μ = randn(d); L = Diagonal(ones(d)); diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 5d0c3f8d..ae0dc684 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -1,6 +1,7 @@ module AdvancedVI +using SimpleUnPack: @unpack, @pack! using Accessors using Random: AbstractRNG, default_rng diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl index 73be42b9..3113c679 100644 --- a/src/distributions/location_scale.jl +++ b/src/distributions/location_scale.jl @@ -35,42 +35,42 @@ Base.length(q::VILocationScale) = length(q.location) Base.size(q::VILocationScale) = size(q.location) function StatsBase.entropy(q::VILocationScale) - (; location, scale, dist) = q + @unpack location, scale, dist = q n_dims = length(location) n_dims*entropy(dist) + first(logabsdet(scale)) end function logpdf(q::VILocationScale, z::AbstractVector{<:Real}) - (; location, scale, dist) = q + @unpack location, scale, dist = q sum(zᵢ -> logpdf(dist, zᵢ), scale \ (z - location)) - first(logabsdet(scale)) end function _logpdf(q::VILocationScale, z::AbstractVector{<:Real}) - (; location, scale, dist) = q + @unpack location, scale, dist = q sum(zᵢ -> logpdf(dist, zᵢ), scale \ (z - location)) - first(logabsdet(scale)) end function rand(q::VILocationScale) - (; location, scale, dist) = q + @unpack location, scale, dist = q n_dims = length(location) scale*rand(dist, n_dims) + location end function 
rand(rng::AbstractRNG, q::VILocationScale, num_samples::Int) - (; location, scale, dist) = q + @unpack location, scale, dist = q n_dims = length(location) scale*rand(rng, dist, n_dims, num_samples) .+ location end function _rand!(rng::AbstractRNG, q::VILocationScale, x::AbstractVector{<:Real}) - (; location, scale, dist) = q + @unpack location, scale, dist = q rand!(rng, dist, x) x .= scale*x return x += location end function _rand!(rng::AbstractRNG, q::VILocationScale, x::AbstractMatrix{<:Real}) - (; location, scale, dist) = q + @unpack location, scale, dist = q rand!(rng, dist, x) x *= scale return x += location diff --git a/test/Project.toml b/test/Project.toml index 277b73c8..2f38c88f 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -14,6 +14,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Random123 = "74087812-796a-5b5d-8853-05524746bad3" ReTest = "e0db7c4e-2690-44b9-bad6-7687da720f89" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl index c6aee68b..71cf22d5 100644 --- a/test/advi_locscale.jl +++ b/test/advi_locscale.jl @@ -8,6 +8,7 @@ using Optimisers using Distributions using PDMats using LinearAlgebra +using SimpleUnPack: @unpack struct TestModel{M,L,S} model::M @@ -47,7 +48,7 @@ include("models/utils.jl") T = 10000 modelstats = modelconstr(realtype; rng) - (; model, μ_true, L_true, n_dims, is_meanfield) = modelstats + @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats b = Bijectors.bijector(model) b⁻¹ = inverse(b) diff --git a/test/models/normal.jl b/test/models/normal.jl index 1dfa653c..f60ad5f3 100644 --- a/test/models/normal.jl +++ b/test/models/normal.jl @@ -5,7 +5,7 @@ struct TestMvNormal{M,S} end function LogDensityProblems.logdensity(model::TestMvNormal, θ) - 
(; μ, Σ) = model + @unpack μ, Σ = model logpdf(MvNormal(μ, Σ), θ) end diff --git a/test/models/normallognormal.jl b/test/models/normallognormal.jl index 49da5bf6..cab73cce 100644 --- a/test/models/normallognormal.jl +++ b/test/models/normallognormal.jl @@ -7,7 +7,7 @@ struct NormalLogNormal{MX,SX,MY,SY} end function LogDensityProblems.logdensity(model::NormalLogNormal, θ) - (; μ_x, σ_x, μ_y, Σ_y) = model + @unpack μ_x, σ_x, μ_y, Σ_y = model logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end]) end @@ -20,7 +20,7 @@ function LogDensityProblems.capabilities(::Type{<:NormalLogNormal}) end function Bijectors.bijector(model::NormalLogNormal) - (; μ_x, σ_x, μ_y, Σ_y) = model + @unpack μ_x, σ_x, μ_y, Σ_y = model Bijectors.Stacked( Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]), [1:1, 2:1+length(μ_y)]) @@ -56,8 +56,7 @@ function normallognormal_meanfield(realtype; rng = default_rng()) μ_y = randn(rng, realtype, n_dims) σ_y = log.(exp.(randn(rng, realtype, n_dims)) .+ 1) - #model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2)) - model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2)) + model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2)) μ = vcat(μ_x, μ_y) L = vcat(σ_x, σ_y) |> Diagonal diff --git a/test/optimize.jl b/test/optimize.jl index 96930495..c1a604c1 100644 --- a/test/optimize.jl +++ b/test/optimize.jl @@ -6,6 +6,7 @@ using Optimisers using Distributions using PDMats using LinearAlgebra +using SimpleUnPack: @unpack struct TestModel{M,L,S} model::M From 6ab95a096e058d21b9df1bb335d09381ce097705 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 21 Aug 2023 22:23:25 -0400 Subject: [PATCH 114/144] Revert "fix remove reamining use of `@unpack`" This reverts commit 817374403e58cb11e4e0e3aaee045c350d5bdfdc. 
--- test/optimize.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/optimize.jl b/test/optimize.jl index c1a604c1..920a3070 100644 --- a/test/optimize.jl +++ b/test/optimize.jl @@ -26,7 +26,7 @@ include("models/utils.jl") T = 1000 modelstats = normallognormal_meanfield(Float64; rng) - (; model, μ_true, L_true, n_dims, is_meanfield) = modelstats + @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats # Global Test Configurations b⁻¹ = Bijectors.bijector(model) |> inverse From f0ec242e615fb9f3f7b4b05ea2a687fa9c0e8b0c Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 22 Aug 2023 18:08:01 +0100 Subject: [PATCH 115/144] fix documentation for `optimize` --- src/optimize.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/optimize.jl b/src/optimize.jl index 57ee8030..b18c8581 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -35,7 +35,7 @@ Optimize the variational objective `objective` by estimating (stochastic) gradie - `optimizer`: Optimizer used for inference. (Type: `<: Optimisers.AbstractRule`; Default: `Adam`.) - `rng`: Random number generator. (Type: `<: AbstractRNG`; Default: `Random.default_rng()`.) - `show_progress`: Whether to show the progress bar. (Type: `<: Bool`; Default: `true`.) -- `callback!`: Callback function called after every iteration. The signature is `cb(; t, est_state, stats, restructure, λ, g)`, which returns a dictionary-like object containing statistics to be displayed on the progress bar. The variational approximation can be reconstructed as `restructure(λ)`. If the estimator associated with `objective` is stateful, `est_state` contains its state. (Default: `nothing`.) `g` is the stochastic gradient. +- `callback!`: Callback function called after every iteration. The signature is `cb(; est_state, stats, restructure, λ, g)`, which returns a dictionary-like object containing statistics to be displayed on the progress bar. 
The variational approximation can be reconstructed as `restructure(λ)`. If the estimator associated with `objective` is stateful, `est_state` contains its state. (Default: `nothing`.) `g` is the stochastic gradient. - `prog`: Progress bar configuration. (Default: `ProgressMeter.Progress(n_max_iter; desc="Optimizing", barlen=31, showspeed=true, enabled=prog)`.) # Returns @@ -80,7 +80,7 @@ function optimize( stat = merge(stat, stat′) if !isnothing(callback!) - stat′ = callback!(; est_state, stat, λ, g, restructure) + stat′ = callback!(; est_state, stat, restructure, λ, g) stat = !isnothing(stat′) ? merge(stat′, stat) : stat end From 1d4c1b6877296a7bdca5ed38c9d34c5be3acc827 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 22 Aug 2023 18:08:13 +0100 Subject: [PATCH 116/144] add specializations of `Optimise.destructure` for mean-field * This fixes the poor performance of `ForwardDiff` * This prevents the zero elements of the mean-field scale being extracted --- docs/src/locscale.md | 3 --- src/distributions/location_scale.jl | 35 ++++++++++++++++++++++++----- 2 files changed, 30 insertions(+), 8 deletions(-) diff --git a/docs/src/locscale.md b/docs/src/locscale.md index 8f14a9ad..63ff5cb4 100644 --- a/docs/src/locscale.md +++ b/docs/src/locscale.md @@ -23,9 +23,6 @@ where ``\mathcal{H}(\varphi)`` is the entropy of the base distribution. Notice the ``\mathcal{H}(\varphi)`` does not depend on ``\log |C|``. The derivative of the entropy with respect to ``\lambda`` is thus independent of the base distribution. -!!! warning - `LocationScale` and its specializations such as `VIFullRankGaussian` and `VIMeanFieldGaussian` are inefficient with forward-mode differentiation packages like `ForwardDiff`. Especially, they scale poorly with the number of dimensions. Please use reverse-mode differentation packages such as `ReverseDiff` and `Zygote`. 
- ## Constructors ```@docs diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl index 3113c679..9ae749f2 100644 --- a/src/distributions/location_scale.jl +++ b/src/distributions/location_scale.jl @@ -19,9 +19,8 @@ struct VILocationScale{L, S, D} <: ContinuousMultivariateDistribution dist ::D function VILocationScale(location::AbstractVector{<:Real}, - scale::Union{<:AbstractTriangular{<:Real}, - <:Diagonal{<:Real}}, - dist::ContinuousUnivariateDistribution) + scale ::Union{<:AbstractTriangular{<:Real}, <:Diagonal{<:Real}}, + dist ::ContinuousUnivariateDistribution) # Restricting all the arguments to have the same types creates problems # with dual-variable-based AD frameworks. @assert (length(location) == size(scale,1)) && (length(location) == size(scale,2)) @@ -31,6 +30,32 @@ end Functors.@functor VILocationScale (location, scale) +# Specialization of `Optimisers.destructure` for mean-field location-scale families. +# These are necessary because we only want to extract the diagonal elements of +# `scale <: Diagonal`, which is not the default behavior. Otherwise, forward-mode AD +# is very inefficient. 
+# begin +struct RestructureMeanField{L, S<:Diagonal, D} + q::VILocationScale{L, S, D} +end + +function (re::RestructureMeanField)(flat::AbstractVector) + n_dims = div(length(flat), 2) + location = first(flat, n_dims) + scale = Diagonal(last(flat, n_dims)) + VILocationScale(location, scale, re.q.dist) +end + +function Optimisers.destructure( + q::VILocationScale{L, <:Diagonal, D} +) where {L, D} + @unpack location, scale, dist = q + flat = vcat(location, diag(scale)) + n_dims = length(location) + flat, RestructureMeanField(q) +end +# end + Base.length(q::VILocationScale) = length(q.location) Base.size(q::VILocationScale) = size(q.location) @@ -42,12 +67,12 @@ end function logpdf(q::VILocationScale, z::AbstractVector{<:Real}) @unpack location, scale, dist = q - sum(zᵢ -> logpdf(dist, zᵢ), scale \ (z - location)) - first(logabsdet(scale)) + sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - first(logabsdet(scale)) end function _logpdf(q::VILocationScale, z::AbstractVector{<:Real}) @unpack location, scale, dist = q - sum(zᵢ -> logpdf(dist, zᵢ), scale \ (z - location)) - first(logabsdet(scale)) + sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - first(logabsdet(scale)) end function rand(q::VILocationScale) From 231835f719f6fce86a4e0cf9935431b53cce75c7 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 22 Aug 2023 20:01:41 +0100 Subject: [PATCH 117/144] add test for `Optimisers.destructure` specializations --- test/distributions.jl | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/test/distributions.jl b/test/distributions.jl index 9b18d020..dcd20696 100644 --- a/test/distributions.jl +++ b/test/distributions.jl @@ -1,6 +1,7 @@ using ReTest using Distributions: _logpdf +using Optimisers @testset "distributions" begin @testset "$(string(covtype)) $(basedist) $(realtype)" for @@ -55,4 +56,15 @@ using Distributions: _logpdf @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) end end + + @testset "Diagonal destructure" for + n_dims = 10 + μ = 
zeros(n_dims) + L = ones(n_dims) + q = VIMeanFieldGaussian(μ, L |> Diagonal) + λ, re = Optimisers.destructure(q) + + @test length(λ) == 2*n_dims + @test q == re(λ) + end end From ea2d426c2c9b96de7d640e9ab0add3b4ae853892 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 22 Aug 2023 21:21:54 +0100 Subject: [PATCH 118/144] add specialization of `rand` for meanfield resulting in faster AD --- src/distributions/location_scale.jl | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl index 9ae749f2..7eb1f708 100644 --- a/src/distributions/location_scale.jl +++ b/src/distributions/location_scale.jl @@ -87,6 +87,16 @@ function rand(rng::AbstractRNG, q::VILocationScale, num_samples::Int) scale*rand(rng, dist, n_dims, num_samples) .+ location end +# This specialization improves AD performance of the sampling path +function rand( + rng::AbstractRNG, q::VILocationScale{L, <:Diagonal, D}, num_samples::Int +) where {L, D} + @unpack location, scale, dist = q + n_dims = length(location) + scale_diag = diag(scale) + scale_diag.*rand(rng, dist, n_dims, num_samples) .+ location +end + function _rand!(rng::AbstractRNG, q::VILocationScale, x::AbstractVector{<:Real}) @unpack location, scale, dist = q rand!(rng, dist, x) From 3033d75938b9d37408bbe081bf73c7954aff09cf Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 22 Aug 2023 21:42:16 +0100 Subject: [PATCH 119/144] add argument checks for `VIMeanFieldGaussian`, `VIFullRankGaussian` --- src/distributions/location_scale.jl | 29 ++++++++++++++++++++++++----- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl index 7eb1f708..a7d9fbe4 100644 --- a/src/distributions/location_scale.jl +++ b/src/distributions/location_scale.jl @@ -6,12 +6,15 @@ The location scale variational family broadly represents various variational families using `location` and `scale` variational parameters. 
It generally represents any distribution for which the sampling path can be -represented as the following: +represented as follows: ```julia d = length(location) u = rand(dist, d) z = scale*u + location ``` + +!!! note + For stable convergence, the initial scale needs to be sufficiently large. """ struct VILocationScale{L, S, D} <: ContinuousMultivariateDistribution location::L @@ -112,21 +115,37 @@ function _rand!(rng::AbstractRNG, q::VILocationScale, x::AbstractMatrix{<:Real}) end """ - VIFullRankGaussian(μ::AbstractVector{T}, L::AbstractTriangular{T}) + VIFullRankGaussian(μ::AbstractVector{T}, L::AbstractTriangular{T}; check_args = true) This constructs a multivariate Gaussian distribution with a full rank covariance matrix. """ -function VIFullRankGaussian(μ::AbstractVector{T}, L::AbstractTriangular{T}) where {T <: Real} +function VIFullRankGaussian( + μ::AbstractVector{T}, + L::AbstractTriangular{T}; + check_args::Bool = true +) where {T <: Real} + @assert isposdef(L) "Scale must be positive definite" + if check_args && (minimum(diag(L)) < sqrt(eps(eltype(L)))) + @warn "Initial scale is too small (minimum eigenvalue is $(minimum(diag(L)))). This might result in unstable optimization behavior." + end q_base = Normal{T}(zero(T), one(T)) VILocationScale(μ, L, q_base) end """ - VIMeanFieldGaussian(μ::AbstractVector{T}, L::Diagonal{T}) + VIMeanFieldGaussian(μ::AbstractVector{T}, L::Diagonal{T}; check_args = true) This constructs a multivariate Gaussian distribution with a diagonal covariance matrix. """ -function VIMeanFieldGaussian(μ::AbstractVector{T}, L::Diagonal{T}) where {T <: Real} +function VIMeanFieldGaussian( + μ::AbstractVector{T}, + L::Diagonal{T}; + check_args::Bool = true +) where {T <: Real} + @assert isposdef(L) "Scale must be positive definite" + if check_args && (minimum(diag(L)) < sqrt(eps(eltype(L)))) + @warn "Initial scale is too small (minimum eigenvalue is $(minimum(diag(L)))). This might result in unstable optimization behavior." 
+ end q_base = Normal{T}(zero(T), one(T)) VILocationScale(μ, L, q_base) end From 0cc36c0eb9f4fc701e73e5ee835e5e0ced0c88d1 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 22 Aug 2023 21:55:02 +0100 Subject: [PATCH 120/144] update documentation --- docs/src/advi.md | 5 ++--- docs/src/locscale.md | 4 ++++ src/distributions/location_scale.jl | 3 --- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/src/advi.md b/docs/src/advi.md index afb780cb..2cf6a773 100644 --- a/docs/src/advi.md +++ b/docs/src/advi.md @@ -210,8 +210,8 @@ _, stats_stl, _ = AVI.optimize( t = [stat.iteration for stat ∈ stats_cfe] y_cfe = [stat.elbo for stat ∈ stats_cfe] y_stl = [stat.elbo for stat ∈ stats_stl] -plot( t, y_cfe, label="ADVI CFE", xlabel="Iteration", ylabel="ELBO") -plot!(t, y_stl, label="ADVI STL", xlabel="Iteration", ylabel="ELBO") +plot( t, y_cfe, label="ADVI CFE", xlabel="Iteration", ylabel="ELBO", ylims=(-50, 10)) +plot!(t, y_stl, label="ADVI STL", xlabel="Iteration", ylabel="ELBO", ylims=(-50, 10)) savefig("advi_stl_elbo.svg") nothing ``` @@ -220,7 +220,6 @@ nothing We can see that the noise of the STL estimator becomes smaller as VI converges. However, the speed of convergence may not always be significantly different. - ## References 1. Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. Journal of machine learning research. 2. Titsias, M., & Lázaro-Gredilla, M. (2014, June). Doubly stochastic variational Bayes for non-conjugate inference. In International conference on machine learning (pp. 1971-1979). PMLR. diff --git a/docs/src/locscale.md b/docs/src/locscale.md index 63ff5cb4..a5966f44 100644 --- a/docs/src/locscale.md +++ b/docs/src/locscale.md @@ -25,6 +25,10 @@ The derivative of the entropy with respect to ``\lambda`` is thus independent of ## Constructors +!!! note + For stable convergence, the initial `scale` needs to be sufficiently large and well-conditioned. 
+ Initializing `scale` to have small eigenvalues will often result in initial divergences and numerical instabilities. + ```@docs VILocationScale ``` diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl index a7d9fbe4..ce14d724 100644 --- a/src/distributions/location_scale.jl +++ b/src/distributions/location_scale.jl @@ -12,9 +12,6 @@ represented as follows: u = rand(dist, d) z = scale*u + location ``` - -!!! note - For stable convergence, the initial scale needs to be sufficiently large. """ struct VILocationScale{L, S, D} <: ContinuousMultivariateDistribution location::L From b7d3471fdd81b44a07dac068f1d84a260bb4959a Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 22 Aug 2023 22:54:18 +0100 Subject: [PATCH 121/144] fix type instability, bug in argument check in `LocationScale` --- src/distributions/location_scale.jl | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl index ce14d724..ab12db84 100644 --- a/src/distributions/location_scale.jl +++ b/src/distributions/location_scale.jl @@ -57,12 +57,15 @@ end # end Base.length(q::VILocationScale) = length(q.location) + Base.size(q::VILocationScale) = size(q.location) +Base.eltype(::Type{<:VILocationScale{L, S, D}}) where {L, S, D} = eltype(D) + function StatsBase.entropy(q::VILocationScale) @unpack location, scale, dist = q n_dims = length(location) - n_dims*entropy(dist) + first(logabsdet(scale)) + n_dims*convert(eltype(location), entropy(dist)) + first(logabsdet(scale)) end function logpdf(q::VILocationScale, z::AbstractVector{<:Real}) @@ -121,7 +124,7 @@ function VIFullRankGaussian( L::AbstractTriangular{T}; check_args::Bool = true ) where {T <: Real} - @assert isposdef(L) "Scale must be positive definite" + @assert eigmin(L) > eps(eltype(L)) "Scale must be positive definite" if check_args && (minimum(diag(L)) < sqrt(eps(eltype(L)))) @warn "Initial scale is too small (minimum 
eigenvalue is $(minimum(diag(L)))). This might result in unstable optimization behavior." end @@ -139,7 +142,7 @@ function VIMeanFieldGaussian( L::Diagonal{T}; check_args::Bool = true ) where {T <: Real} - @assert isposdef(L) "Scale must be positive definite" + @assert eigmin(L) > eps(eltype(L)) "Scale must be a Cholesky factor" if check_args && (minimum(diag(L)) < sqrt(eps(eltype(L)))) @warn "Initial scale is too small (minimum eigenvalue is $(minimum(diag(L)))). This might result in unstable optimization behavior." end From df50e8346e2d3174c6e57f41812e25f5d9c9751e Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 22 Aug 2023 22:57:24 +0100 Subject: [PATCH 122/144] add missing import bug --- src/AdvancedVI.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index ae0dc684..16807542 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -7,7 +7,7 @@ using Accessors using Random: AbstractRNG, default_rng using Distributions import Distributions: - logpdf, _logpdf, rand, _rand!, _rand!, + logpdf, _logpdf, rand, rand!, _rand!, ContinuousMultivariateDistribution using Functors @@ -26,7 +26,6 @@ using ADTypes: AbstractADType using ChainRules: @ignore_derivatives using FillArrays -using PDMats using Bijectors using StatsBase From ae3e9b018518b803ed60b6eaf7c5400cdf040a10 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 22 Aug 2023 22:57:43 +0100 Subject: [PATCH 123/144] refactor test, fix type bug in tests for `LocationScale` --- test/ad.jl | 2 -- test/advi_locscale.jl | 24 +++--------------------- test/distributions.jl | 27 ++++++++++++--------------- test/models/utils.jl | 8 -------- test/optimize.jl | 18 ------------------ test/runtests.jl | 23 +++++++++++++++++++++++ 6 files changed, 38 insertions(+), 64 deletions(-) delete mode 100644 test/models/utils.jl diff --git a/test/ad.jl b/test/ad.jl index 2c4f802a..f575b485 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -1,7 +1,5 @@ using ReTest -using ForwardDiff, 
ReverseDiff, Enzyme, Zygote -using ADTypes @testset "ad" begin @testset "$(adname)" for (adname, adsymbol) ∈ Dict( diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl index 71cf22d5..a7dcc98b 100644 --- a/test/advi_locscale.jl +++ b/test/advi_locscale.jl @@ -2,25 +2,6 @@ const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress" ? true : false using ReTest -using Bijectors -using LogDensityProblems -using Optimisers -using Distributions -using PDMats -using LinearAlgebra -using SimpleUnPack: @unpack - -struct TestModel{M,L,S} - model::M - μ_true::L - L_true::S - n_dims::Int - is_meanfield::Bool -end - -include("models/normallognormal.jl") -include("models/normal.jl") -include("models/utils.jl") @testset "advi" begin @testset "locscale" begin @@ -55,10 +36,11 @@ include("models/utils.jl") μ₀ = zeros(realtype, n_dims) L₀ = if is_meanfield - ones(realtype, n_dims) |> Diagonal + FillArrays.Eye(n_dims) |> Diagonal else - diagm(ones(realtype, n_dims)) |> LowerTriangular + FillArrays.Eye(n_dims) |> LowerTriangular end + q₀ = if is_meanfield VIMeanFieldGaussian(μ₀, L₀) else diff --git a/test/distributions.jl b/test/distributions.jl index dcd20696..563de12d 100644 --- a/test/distributions.jl +++ b/test/distributions.jl @@ -1,7 +1,6 @@ using ReTest using Distributions: _logpdf -using Optimisers @testset "distributions" begin @testset "$(string(covtype)) $(basedist) $(realtype)" for @@ -11,35 +10,33 @@ using Optimisers seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797) rng = Philox4x(UInt64, seed, 8) - realtype = Float64 - ϵ = 1f-2 n_dims = 10 n_montecarlo = 1000_000 - μ = randn(rng, realtype, n_dims) - L₀ = randn(rng, realtype, n_dims, n_dims) |> LowerTriangular - Σ = if covtype == :fullrank - Σ = (L₀*L₀' + ϵ*I) |> Hermitian + μ = randn(rng, realtype, n_dims) + L = if covtype == :fullrank + sample_cholesky(rng, realtype, n_dims) else Diagonal(log.(exp.(randn(rng, realtype, n_dims)) .+ 1)) end + Σ = L*L' - L = cholesky(Σ).L q = if covtype == :fullrank && basedist == 
:gaussian - VIFullRankGaussian(μ, L |> LowerTriangular) + VIFullRankGaussian(μ, L) elseif covtype == :meanfield && basedist == :gaussian - VIMeanFieldGaussian(μ, L |> Diagonal) + VIMeanFieldGaussian(μ, L) end q_true = if basedist == :gaussian MvNormal(μ, Σ) end @testset "logpdf" begin - z = randn(rng, realtype, n_dims) - @test logpdf(q, z) ≈ logpdf(q_true, z) - @test _logpdf(q, z) ≈ _logpdf(q_true, z) - @test eltype(logpdf(q, z)) == realtype - @test eltype(_logpdf(q, z)) == realtype + z = rand(rng, q) + @test eltype(z) == realtype + @test logpdf(q, z) ≈ logpdf(q_true, z) rtol=realtype(1e-2) + @test _logpdf(q, z) ≈ _logpdf(q_true, z) rtol=realtype(1e-2) + @test eltype(logpdf(q, z)) == realtype + @test eltype(_logpdf(q, z)) == realtype end @testset "entropy" begin diff --git a/test/models/utils.jl b/test/models/utils.jl deleted file mode 100644 index 3d483c46..00000000 --- a/test/models/utils.jl +++ /dev/null @@ -1,8 +0,0 @@ - -function sample_cholesky(rng::AbstractRNG, type::Type, n_dims::Int) - A = randn(rng, type, n_dims, n_dims) - L = tril(A) - idx = diagind(L) - @. 
L[idx] = log(exp(L[idx]) + 1) - L |> LowerTriangular -end diff --git a/test/optimize.jl b/test/optimize.jl index 920a3070..5686b724 100644 --- a/test/optimize.jl +++ b/test/optimize.jl @@ -1,23 +1,5 @@ using ReTest -using Bijectors -using LogDensityProblems -using Optimisers -using Distributions -using PDMats -using LinearAlgebra -using SimpleUnPack: @unpack - -struct TestModel{M,L,S} - model::M - μ_true::L - L_true::S - n_dims::Int - is_meanfield::Bool -end - -include("models/normallognormal.jl") -include("models/utils.jl") @testset "optimize" begin seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797) diff --git a/test/runtests.jl b/test/runtests.jl index 6bd3bc49..803c11c7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,9 +8,32 @@ using Random123 using Statistics using Distributions using LinearAlgebra +using SimpleUnPack: @unpack +using PDMats + +using Bijectors +using LogDensityProblems +using Optimisers +using ADTypes +using ForwardDiff, ReverseDiff, Zygote using AdvancedVI +# Utilities +include("utils.jl") + +struct TestModel{M,L,S} + model::M + μ_true::L + L_true::S + n_dims::Int + is_meanfield::Bool +end + +include("models/normal.jl") +include("models/normallognormal.jl") + +# Tests include("ad.jl") include("distributions.jl") include("advi_locscale.jl") From e4002cfeb0f8edd7dd8cf02e6ee68f1eb2bf959a Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 22 Aug 2023 22:58:08 +0100 Subject: [PATCH 124/144] add missing compat entries --- Project.toml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index e099308a..87aa4aac 100644 --- a/Project.toml +++ b/Project.toml @@ -15,7 +15,6 @@ Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" -PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = 
"9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" @@ -37,17 +36,21 @@ AdvancedVIZygoteExt = "Zygote" [compat] ADTypes = "0.1" +Accessors = "0.1.32" Bijectors = "0.11, 0.12, 0.13" ChainRules = "1.53.0" DiffResults = "1" Distributions = "0.21, 0.22, 0.23, 0.24, 0.25" DocStringExtensions = "0.8, 0.9" +FillArrays = "1.6.0" ForwardDiff = "0.10.25" +Functors = "0.4.5" LogDensityProblems = "2.1.1" Optimisers = "0.2.16" ProgressMeter = "1.0.0" Requires = "0.5, 1.0" ReverseDiff = "1.14" +SimpleUnPack = "1.1.0" StatsBase = "0.32, 0.33, 0.34" StatsFuns = "0.8, 0.9, 1" julia = "1.6" From 8c82569208199480676de7b583cd54ff079ba8c5 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 22 Aug 2023 23:19:26 +0100 Subject: [PATCH 125/144] fix missing package import in test --- test/Project.toml | 1 + test/runtests.jl | 1 + 2 files changed, 2 insertions(+) diff --git a/test/Project.toml b/test/Project.toml index 2f38c88f..663d671d 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -4,6 +4,7 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" Comonicon = "863f3e99-da2a-4334-8734-de3dacbe5542" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" diff --git a/test/runtests.jl b/test/runtests.jl index 803c11c7..8a6e486e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,6 +10,7 @@ using Distributions using LinearAlgebra using SimpleUnPack: @unpack using PDMats +using FillArrays using Bijectors using LogDensityProblems From c2e751723a63cd00b5f223a390dc34513b94b946 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 22 Aug 2023 23:19:34 +0100 Subject: [PATCH 126/144] add additional tests for sampling `LocationScale` --- test/distributions.jl | 41 
+++++++++++++++++++++++++++++++++++------ 1 file changed, 35 insertions(+), 6 deletions(-) diff --git a/test/distributions.jl b/test/distributions.jl index 563de12d..c603421e 100644 --- a/test/distributions.jl +++ b/test/distributions.jl @@ -31,6 +31,9 @@ using Distributions: _logpdf end @testset "logpdf" begin + seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797) + rng = Philox4x(UInt64, seed, 8) + z = rand(rng, q) @test eltype(z) == realtype @test logpdf(q, z) ≈ logpdf(q_true, z) rtol=realtype(1e-2) @@ -45,12 +48,38 @@ using Distributions: _logpdf end @testset "sampling" begin - z_samples = rand(rng, q, n_montecarlo) - threesigma = L - @test eltype(z_samples) == realtype - @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ rtol=realtype(1e-2) - @test dropdims(var(z_samples, dims=2), dims=2) ≈ diag(Σ) rtol=realtype(1e-2) - @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) + @testset "rand" begin + seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797) + rng = Philox4x(UInt64, seed, 8) + + z_samples = mapreduce(x -> rand(rng, q), hcat, 1:n_montecarlo) + @test eltype(z_samples) == realtype + @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ rtol=realtype(1e-2) + @test dropdims(var(z_samples, dims=2), dims=2) ≈ diag(Σ) rtol=realtype(1e-2) + @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) + end + + @testset "rand batch" begin + seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797) + rng = Philox4x(UInt64, seed, 8) + + z_samples = rand(rng, q, n_montecarlo) + @test eltype(z_samples) == realtype + @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ rtol=realtype(1e-2) + @test dropdims(var(z_samples, dims=2), dims=2) ≈ diag(Σ) rtol=realtype(1e-2) + @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) + end + + @testset "rand!" 
begin + seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797) + rng = Philox4x(UInt64, seed, 8) + + z_samples = Array{realtype}(undef, n_dims, n_montecarlo) + rand!(rng, q, z_samples) + @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ rtol=realtype(1e-2) + @test dropdims(var(z_samples, dims=2), dims=2) ≈ diag(Σ) rtol=realtype(1e-2) + @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) + end end end From 3a6f8bf689af5657d817674d84a886d3496864d6 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 22 Aug 2023 23:19:50 +0100 Subject: [PATCH 127/144] fix bug in batch in-place `rand!` for `LocationScale` --- src/distributions/location_scale.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl index ab12db84..ecb0b672 100644 --- a/src/distributions/location_scale.jl +++ b/src/distributions/location_scale.jl @@ -110,8 +110,8 @@ end function _rand!(rng::AbstractRNG, q::VILocationScale, x::AbstractMatrix{<:Real}) @unpack location, scale, dist = q rand!(rng, dist, x) - x *= scale - return x += location + x[:] = scale*x + return x .+= location end """ From b78ef4bf3afe6649d124320595edde71d3031e02 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Tue, 22 Aug 2023 23:39:16 +0100 Subject: [PATCH 128/144] fix bug in inference test initialization --- test/advi_locscale.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl index a7dcc98b..76ae3724 100644 --- a/test/advi_locscale.jl +++ b/test/advi_locscale.jl @@ -38,7 +38,7 @@ using ReTest L₀ = if is_meanfield FillArrays.Eye(n_dims) |> Diagonal else - FillArrays.Eye(n_dims) |> LowerTriangular + FillArrays.Eye(n_dims) |> Matrix |> LowerTriangular end q₀ = if is_meanfield From a1f7e98a612bc8e7b840c4c341ee8870aac9e29f Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Wed, 23 Aug 2023 01:29:50 +0100 Subject: [PATCH 129/144] add missing file --- test/utils.jl | 8 ++++++++ 1 file changed, 8 
insertions(+) create mode 100644 test/utils.jl diff --git a/test/utils.jl b/test/utils.jl new file mode 100644 index 00000000..3d483c46 --- /dev/null +++ b/test/utils.jl @@ -0,0 +1,8 @@ + +function sample_cholesky(rng::AbstractRNG, type::Type, n_dims::Int) + A = randn(rng, type, n_dims, n_dims) + L = tril(A) + idx = diagind(L) + @. L[idx] = log(exp(L[idx]) + 1) + L |> LowerTriangular +end From 8b783eca14a21cc620f311f9f63417e9f31e5de8 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 22 Aug 2023 21:46:01 -0400 Subject: [PATCH 130/144] fix remove use of for 1.6 --- src/distributions/location_scale.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl index ecb0b672..91b6768a 100644 --- a/src/distributions/location_scale.jl +++ b/src/distributions/location_scale.jl @@ -124,7 +124,7 @@ function VIFullRankGaussian( L::AbstractTriangular{T}; check_args::Bool = true ) where {T <: Real} - @assert eigmin(L) > eps(eltype(L)) "Scale must be positive definite" + @assert minimum(diag(L)) > eps(eltype(L)) "Scale must be positive definite" if check_args && (minimum(diag(L)) < sqrt(eps(eltype(L)))) @warn "Initial scale is too small (minimum eigenvalue is $(minimum(diag(L)))). This might result in unstable optimization behavior." end @@ -142,7 +142,7 @@ function VIMeanFieldGaussian( L::Diagonal{T}; check_args::Bool = true ) where {T <: Real} - @assert eigmin(L) > eps(eltype(L)) "Scale must be a Cholesky factor" + @assert minimum(diag(L)) > eps(eltype(L)) "Scale must be a Cholesky factor" if check_args && (minimum(diag(L)) < sqrt(eps(eltype(L)))) @warn "Initial scale is too small (minimum eigenvalue is $(minimum(diag(L)))). This might result in unstable optimization behavior." 
end From 12cd9f22611f3bf1a95ea878ade7c3f151957cd9 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Wed, 23 Aug 2023 21:00:51 +0100 Subject: [PATCH 131/144] refactor adjust inference test hyperparameters to be more robust --- test/advi_locscale.jl | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl index 76ae3724..524dc5e2 100644 --- a/test/advi_locscale.jl +++ b/test/advi_locscale.jl @@ -27,10 +27,11 @@ using ReTest seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797) rng = Philox4x(UInt64, seed, 8) - T = 10000 modelstats = modelconstr(realtype; rng) @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats + T, η = is_meanfield ? (5_000, 1e-2) : (30_000, 1e-3) + b = Bijectors.bijector(model) b⁻¹ = inverse(b) @@ -53,7 +54,7 @@ using ReTest Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true) q, stats, _ = optimize( obj, q₀, T; - optimizer = Optimisers.Adam(1e-2), + optimizer = Optimisers.Adam(realtype(η)), show_progress = PROGRESS, rng = rng, adbackend = adbackend, @@ -72,7 +73,7 @@ using ReTest rng = Philox4x(UInt64, seed, 8) q, stats, _ = optimize( obj, q₀, T; - optimizer = Optimisers.Adam(realtype(1e-2)), + optimizer = Optimisers.Adam(realtype(η)), show_progress = PROGRESS, rng = rng, adbackend = adbackend, @@ -83,7 +84,7 @@ using ReTest rng_repl = Philox4x(UInt64, seed, 8) q, stats, _ = optimize( obj, q₀, T; - optimizer = Optimisers.Adam(realtype(1e-2)), + optimizer = Optimisers.Adam(realtype(η)), show_progress = PROGRESS, rng = rng_repl, adbackend = adbackend, From 837c7296467ae20c66f7c061a6142295ebe50b22 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 24 Aug 2023 02:43:39 +0100 Subject: [PATCH 132/144] refactor `optimize` to return `obj_state`, add warm start kwargs --- docs/src/advi.md | 4 +-- docs/src/started.md | 4 +-- src/AdvancedVI.jl | 6 ----- src/objectives/elbo/advi.jl | 49 +++++++++++++++++-------------------- src/optimize.jl | 27 +++++++++++++------- test/advi_locscale.jl | 12 
++++----- test/optimize.jl | 14 +++++------ 7 files changed, 57 insertions(+), 59 deletions(-) diff --git a/docs/src/advi.md b/docs/src/advi.md index 2cf6a773..3ac90436 100644 --- a/docs/src/advi.md +++ b/docs/src/advi.md @@ -189,7 +189,7 @@ stl = AVI.ADVI(model, n_montecarlo; entropy = AVI.StickingTheLandingEntropy(), i ```@setup stl n_max_iter = 10^4 -_, stats_cfe, _ = AVI.optimize( +_, stats_cfe, _, _ = AVI.optimize( cfe, q0, n_max_iter; @@ -198,7 +198,7 @@ _, stats_cfe, _ = AVI.optimize( optimizer = Optimisers.Adam(1e-3) ); -_, stats_stl, _ = AVI.optimize( +_, stats_stl, _, _ = AVI.optimize( stl, q0, n_max_iter; diff --git a/docs/src/started.md b/docs/src/started.md index 4e2b4380..f3ae54b1 100644 --- a/docs/src/started.md +++ b/docs/src/started.md @@ -103,8 +103,8 @@ q = AVI.VIMeanFieldGaussian(μ, L) ``` Passing `objective` and the initial variational approximation `q` to `optimize` performs inference. ```@example advi -n_max_iter = 10^4 -q, stats, _ = AVI.optimize( +n_max_iter = 10^4 +q, stats, _, _ = AVI.optimize( objective, q, n_max_iter; diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 16807542..9bc3d316 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -53,14 +53,8 @@ abstract type AbstractVariationalObjective end function init end function estimate_gradient end -init(::Nothing) = nothing - # ADVI-specific interfaces abstract type AbstractEntropyEstimator end -abstract type AbstractControlVariate end - -function update end -update(::Nothing, ::Nothing) = (nothing, nothing) # entropy.jl must preceed advi.jl include("objectives/elbo/entropy.jl") diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl index d8719fa7..f9a61d81 100644 --- a/src/objectives/elbo/advi.jl +++ b/src/objectives/elbo/advi.jl @@ -19,18 +19,15 @@ Automatic differentiation variational inference (ADVI; Kucukelbir *et al.* 2017) Depending on the options, additional requirements on ``q_{\\lambda}`` may apply. 
""" -struct ADVI{Tlogπ, B, - EntropyEst <: AbstractEntropyEstimator, - ControlVar <: Union{<: AbstractControlVariate, Nothing}} <: AbstractVariationalObjective - ℓπ::Tlogπ - invbij::B - entropy::EntropyEst - cv::ControlVar +struct ADVI{P, B, EntropyEst <: AbstractEntropyEstimator} <: AbstractVariationalObjective + prob ::P + invbij ::B + entropy ::EntropyEst n_samples::Int - function ADVI(prob, n_samples::Int; - entropy::AbstractEntropyEstimator = ClosedFormEntropy(), - cv::Union{<:AbstractControlVariate, Nothing} = nothing, + function ADVI(prob, + n_samples::Int; + entropy ::AbstractEntropyEstimator = ClosedFormEntropy(), invbij = Bijectors.identity) cap = LogDensityProblems.capabilities(prob) if cap === nothing @@ -40,15 +37,16 @@ struct ADVI{Tlogπ, B, ), ) end - ℓπ = Base.Fix1(LogDensityProblems.logdensity, prob) - new{typeof(ℓπ), typeof(invbij), typeof(entropy), typeof(cv)}(ℓπ, invbij, entropy, cv, n_samples) + new{typeof(prob), typeof(invbij), typeof(entropy)}( + prob, invbij, entropy, n_samples + ) end end Base.show(io::IO, advi::ADVI) = - print(io, "ADVI(entropy=$(advi.entropy), cv=$(advi.cv), n_samples=$(advi.n_samples))") + print(io, "ADVI(entropy=$(advi.entropy), n_samples=$(advi.n_samples))") -init(advi::ADVI) = init(advi.cv) +init(rng::AbstractRNG, advi::ADVI, λ::AbstractVector, restructure) = nothing function (advi::ADVI)( rng::AbstractRNG, @@ -57,7 +55,7 @@ function (advi::ADVI)( ) 𝔼ℓ = mean(eachcol(ηs)) do ηᵢ zᵢ, logdetjacᵢ = Bijectors.with_logabsdet_jacobian(advi.invbij, ηᵢ) - advi.ℓπ(zᵢ) + logdetjacᵢ + LogDensityProblems.logdensity(advi.prob, zᵢ) + logdetjacᵢ end ℍ = advi.entropy(q_η, ηs) 𝔼ℓ + ℍ @@ -78,22 +76,22 @@ Evaluate the ELBO using the ADVI formulation. 
""" function (advi::ADVI)( - q_η::ContinuousMultivariateDistribution; - rng::AbstractRNG = default_rng(), - n_samples::Int = advi.n_samples + q_η ::ContinuousMultivariateDistribution; + rng ::AbstractRNG = default_rng(), + n_samples::Int = advi.n_samples ) ηs = rand(rng, q_η, n_samples) advi(rng, q_η, ηs) end function estimate_gradient( - rng::AbstractRNG, - adbackend::AbstractADType, - advi::ADVI, + rng ::AbstractRNG, + adbackend ::AbstractADType, + advi ::ADVI, est_state, - λ::Vector{<:Real}, + λ ::Vector{<:Real}, restructure, - out::DiffResults.MutableDiffResult + out ::DiffResults.MutableDiffResult ) f(λ′) = begin q_η = restructure(λ′) @@ -105,8 +103,5 @@ function estimate_gradient( nelbo = DiffResults.value(out) stat = (elbo=-nelbo,) - est_state, stat′ = update(advi.cv, est_state) - stat = !isnothing(stat′) ? merge(stat′, stat) : stat - - out, est_state, stat + out, nothing, stat end diff --git a/src/optimize.jl b/src/optimize.jl index b18c8581..54e7ace0 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -35,13 +35,18 @@ Optimize the variational objective `objective` by estimating (stochastic) gradie - `optimizer`: Optimizer used for inference. (Type: `<: Optimisers.AbstractRule`; Default: `Adam`.) - `rng`: Random number generator. (Type: `<: AbstractRNG`; Default: `Random.default_rng()`.) - `show_progress`: Whether to show the progress bar. (Type: `<: Bool`; Default: `true`.) -- `callback!`: Callback function called after every iteration. The signature is `cb(; est_state, stats, restructure, λ, g)`, which returns a dictionary-like object containing statistics to be displayed on the progress bar. The variational approximation can be reconstructed as `restructure(λ)`. If the estimator associated with `objective` is stateful, `est_state` contains its state. (Default: `nothing`.) `g` is the stochastic gradient. +- `callback!`: Callback function called after every iteration. 
The signature is `cb(; obj_state, stats, restructure, λ, g)`, which returns a dictionary-like object containing statistics to be displayed on the progress bar. The variational approximation can be reconstructed as `restructure(λ)`. If the estimator associated with `objective` is stateful, `obj_state` contains its state. (Default: `nothing`.) `g` is the stochastic gradient. - `prog`: Progress bar configuration. (Default: `ProgressMeter.Progress(n_max_iter; desc="Optimizing", barlen=31, showspeed=true, enabled=prog)`.) +When resuming from the state of a previous run, use the following keyword arguments: +- `opt_state`: Initial state of the optimizer. +- `obj_state`: Initial state of the objective. + # Returns - `λ`: Variational parameters optimizing the variational objective. - `stats`: Statistics gathered during inference. - `opt_state`: Final state of the optimiser. +- `obj_state`: Final state of the objective. """ function optimize( objective ::AbstractVariationalObjective, @@ -52,6 +57,8 @@ function optimize( optimizer ::Optimisers.AbstractRule = Optimisers.Adam(), rng ::AbstractRNG = default_rng(), show_progress::Bool = true, + opt_state = nothing, + obj_state = nothing, callback! = nothing, prog = ProgressMeter.Progress( n_max_iter; @@ -62,16 +69,16 @@ function optimize( ) ) λ = copy(λ₀) - opt_state = Optimisers.setup(optimizer, λ) - est_state = init(objective) + opt_state = isnothing(opt_state) ? Optimisers.setup(optimizer, λ) : opt_state + obj_state = isnothing(obj_state) ? 
init(rng, objective, λ, restructure) : obj_state grad_buf = DiffResults.GradientResult(λ) stats = NamedTuple[] for t = 1:n_max_iter stat = (iteration=t,) - grad_buf, est_state, stat′ = estimate_gradient( - rng, adbackend, objective, est_state, λ, restructure, grad_buf) + grad_buf, obj_state, stat′ = estimate_gradient( + rng, adbackend, objective, obj_state, λ, restructure, grad_buf) stat = merge(stat, stat′) g = DiffResults.gradient(grad_buf) @@ -80,7 +87,7 @@ function optimize( stat = merge(stat, stat′) if !isnothing(callback!) - stat′ = callback!(; est_state, stat, restructure, λ, g) + stat′ = callback!(; obj_state, stat, restructure, λ, g) stat = !isnothing(stat′) ? merge(stat′, stat) : stat end @@ -89,7 +96,7 @@ function optimize( pm_next!(prog, stat) push!(stats, stat) end - λ, map(identity, stats), opt_state + λ, map(identity, stats), opt_state, obj_state end function optimize(objective ::AbstractVariationalObjective, @@ -97,6 +104,8 @@ function optimize(objective ::AbstractVariationalObjective, n_max_iter::Int; kwargs...) λ, restructure = Optimisers.destructure(q₀) - λ, stats, opt_state = optimize(objective, restructure, λ, n_max_iter; kwargs...) - restructure(λ), stats, opt_state + λ, stats, opt_state, obj_state = optimize( + objective, restructure, λ, n_max_iter; kwargs... 
+ ) + restructure(λ), stats, opt_state, obj_state end diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl index 524dc5e2..e780b074 100644 --- a/test/advi_locscale.jl +++ b/test/advi_locscale.jl @@ -51,8 +51,8 @@ using ReTest obj = objective(model, b⁻¹, 10) @testset "convergence" begin - Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true) - q, stats, _ = optimize( + Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true) + q, stats, _, _ = optimize( obj, q₀, T; optimizer = Optimisers.Adam(realtype(η)), show_progress = PROGRESS, @@ -70,8 +70,8 @@ using ReTest end @testset "determinism" begin - rng = Philox4x(UInt64, seed, 8) - q, stats, _ = optimize( + rng = Philox4x(UInt64, seed, 8) + q, stats, _, _ = optimize( obj, q₀, T; optimizer = Optimisers.Adam(realtype(η)), show_progress = PROGRESS, @@ -81,8 +81,8 @@ using ReTest μ = q.location L = q.scale - rng_repl = Philox4x(UInt64, seed, 8) - q, stats, _ = optimize( + rng_repl = Philox4x(UInt64, seed, 8) + q, stats, _, _ = optimize( obj, q₀, T; optimizer = Optimisers.Adam(realtype(η)), show_progress = PROGRESS, diff --git a/test/optimize.jl b/test/optimize.jl index 5686b724..2369432c 100644 --- a/test/optimize.jl +++ b/test/optimize.jl @@ -20,8 +20,8 @@ using ReTest adbackend = AutoForwardDiff() optimizer = Optimisers.Adam(1e-2) - rng = Philox4x(UInt64, seed, 8) - q_ref, stats_ref, _ = optimize( + rng = Philox4x(UInt64, seed, 8) + q_ref, stats_ref, _, _ = optimize( obj, q₀, T; optimizer, show_progress = false, @@ -33,8 +33,8 @@ using ReTest @testset "restructure" begin λ₀, re = Optimisers.destructure(q₀) - rng = Philox4x(UInt64, seed, 8) - λ, stats, _ = optimize( + rng = Philox4x(UInt64, seed, 8) + λ, stats, _, _ = optimize( obj, re, λ₀, T; optimizer, show_progress = false, @@ -49,12 +49,12 @@ using ReTest rng = Philox4x(UInt64, seed, 8) test_values = rand(rng, T) - callback!(; stat, est_state, restructure, λ, g) = begin + callback!(; stat, obj_state, restructure, λ, g) = begin (test_value = 
test_values[stat.iteration],) end - rng = Philox4x(UInt64, seed, 8) - _, stats, _ = optimize( + rng = Philox4x(UInt64, seed, 8) + _, stats, _, _ = optimize( obj, q₀, T; show_progress = false, rng, From 95629a5471f7e3e94a19b8096cd9df73d8dad523 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Wed, 23 Aug 2023 23:19:09 -0400 Subject: [PATCH 133/144] refactor make tests more robust, reduce amount of tests --- test/advi_locscale.jl | 2 -- test/distributions.jl | 2 +- test/models/normal.jl | 50 ---------------------------------- test/models/normallognormal.jl | 2 +- test/runtests.jl | 5 +--- test/utils.jl | 8 ------ 6 files changed, 3 insertions(+), 66 deletions(-) delete mode 100644 test/models/normal.jl delete mode 100644 test/utils.jl diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl index e780b074..d5250ce8 100644 --- a/test/advi_locscale.jl +++ b/test/advi_locscale.jl @@ -10,8 +10,6 @@ using ReTest (modelname, modelconstr) ∈ Dict( :NormalLogNormalMeanField => normallognormal_meanfield, :NormalLogNormalFullRank => normallognormal_fullrank, - :NormalMeanField => normal_meanfield, - :NormalFullRank => normal_fullrank, ), (objname, objective) ∈ Dict( :ADVIClosedFormEntropy => (model, b⁻¹, M) -> ADVI(model, M; invbij = b⁻¹), diff --git a/test/distributions.jl b/test/distributions.jl index c603421e..175cc96b 100644 --- a/test/distributions.jl +++ b/test/distributions.jl @@ -15,7 +15,7 @@ using Distributions: _logpdf μ = randn(rng, realtype, n_dims) L = if covtype == :fullrank - sample_cholesky(rng, realtype, n_dims) + tril(I + ones(realtype, n_dims, n_dims)/2) |> LowerTriangular else Diagonal(log.(exp.(randn(rng, realtype, n_dims)) .+ 1)) end diff --git a/test/models/normal.jl b/test/models/normal.jl deleted file mode 100644 index f60ad5f3..00000000 --- a/test/models/normal.jl +++ /dev/null @@ -1,50 +0,0 @@ - -struct TestMvNormal{M,S} - μ::M - Σ::S -end - -function LogDensityProblems.logdensity(model::TestMvNormal, θ) - @unpack μ, Σ = model - logpdf(MvNormal(μ, 
Σ), θ) -end - -function LogDensityProblems.dimension(model::TestMvNormal) - length(model.μ) -end - -function LogDensityProblems.capabilities(::Type{<:TestMvNormal}) - LogDensityProblems.LogDensityOrder{0}() -end - -function Bijectors.bijector(model::TestMvNormal) - identity -end - -function normal_fullrank(realtype; rng = default_rng()) - n_dims = 5 - - μ = randn(rng, realtype, n_dims) - L₀ = sample_cholesky(rng, realtype, n_dims) - Σ = L₀*L₀' |> Hermitian - - Σ_chol = cholesky(Σ) - model = TestMvNormal(μ, PDMats.PDMat(Σ, Σ_chol)) - - L = Σ_chol.L |> LowerTriangular - - TestModel(model, μ, L, n_dims, false) -end - -function normal_meanfield(realtype; rng = default_rng()) - n_dims = 5 - - μ = randn(rng, realtype, n_dims) - σ = log.(exp.(randn(rng, realtype, n_dims)) .+ 1) - - model = TestMvNormal(μ, PDMats.PDiagMat(σ)) - - L = σ |> Diagonal - - TestModel(model, μ, L, n_dims, true) -end diff --git a/test/models/normallognormal.jl b/test/models/normallognormal.jl index cab73cce..f8b84a1b 100644 --- a/test/models/normallognormal.jl +++ b/test/models/normallognormal.jl @@ -32,7 +32,7 @@ function normallognormal_fullrank(realtype; rng = default_rng()) μ_x = randn(rng, realtype) σ_x = ℯ μ_y = randn(rng, realtype, n_dims) - L₀_y = sample_cholesky(rng, realtype, n_dims) + L₀_y = tril(I + ones(realtype, n_dims, n_dims))/2 |> LowerTriangular Σ_y = L₀_y*L₀_y' |> Hermitian model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDMat(Σ_y)) diff --git a/test/runtests.jl b/test/runtests.jl index 8a6e486e..0a2c5e66 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -20,9 +20,7 @@ using ForwardDiff, ReverseDiff, Zygote using AdvancedVI -# Utilities -include("utils.jl") - +# Models for Inference Tests struct TestModel{M,L,S} model::M μ_true::L @@ -31,7 +29,6 @@ struct TestModel{M,L,S} is_meanfield::Bool end -include("models/normal.jl") include("models/normallognormal.jl") # Tests diff --git a/test/utils.jl b/test/utils.jl deleted file mode 100644 index 3d483c46..00000000 --- 
a/test/utils.jl +++ /dev/null @@ -1,8 +0,0 @@ - -function sample_cholesky(rng::AbstractRNG, type::Type, n_dims::Int) - A = randn(rng, type, n_dims, n_dims) - L = tril(A) - idx = diagind(L) - @. L[idx] = log(exp(L[idx]) + 1) - L |> LowerTriangular -end From 0b4b865ae9376b35b776afca17baf58cea27b095 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 24 Aug 2023 00:31:09 -0400 Subject: [PATCH 134/144] fix remove a cholesky in test model --- test/models/normallognormal.jl | 14 +++++++------- test/runtests.jl | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/test/models/normallognormal.jl b/test/models/normallognormal.jl index f8b84a1b..ec591f2c 100644 --- a/test/models/normallognormal.jl +++ b/test/models/normallognormal.jl @@ -29,13 +29,13 @@ end function normallognormal_fullrank(realtype; rng = default_rng()) n_dims = 5 - μ_x = randn(rng, realtype) - σ_x = ℯ - μ_y = randn(rng, realtype, n_dims) - L₀_y = tril(I + ones(realtype, n_dims, n_dims))/2 |> LowerTriangular - Σ_y = L₀_y*L₀_y' |> Hermitian + μ_x = randn(rng, realtype) + σ_x = ℯ + μ_y = randn(rng, realtype, n_dims) + L_y = tril(I + ones(realtype, n_dims, n_dims))/2 |> LowerTriangular + Σ_y = L_y*L_y' |> Hermitian - model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDMat(Σ_y)) + model = NormalLogNormal(μ_x, σ_x, μ_y, PDMat(Σ_y, Cholesky(L_y))) Σ = Matrix{realtype}(undef, n_dims+1, n_dims+1) Σ[1,1] = σ_x^2 @@ -56,7 +56,7 @@ function normallognormal_meanfield(realtype; rng = default_rng()) μ_y = randn(rng, realtype, n_dims) σ_y = log.(exp.(randn(rng, realtype, n_dims)) .+ 1) - model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2)) + model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2)) μ = vcat(μ_x, μ_y) L = vcat(σ_x, σ_y) |> Diagonal diff --git a/test/runtests.jl b/test/runtests.jl index 0a2c5e66..127503be 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -9,8 +9,8 @@ using Statistics using Distributions using LinearAlgebra using SimpleUnPack: @unpack -using PDMats using 
FillArrays +using PDMats using Bijectors using LogDensityProblems From b49f4ebc163e2feecba38fba2678e650dfbd788d Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 24 Aug 2023 00:31:34 -0400 Subject: [PATCH 135/144] fix compat bounds, remove unused package --- Project.toml | 28 ++++++++++++++-------------- src/AdvancedVI.jl | 2 +- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/Project.toml b/Project.toml index 87aa4aac..143e2098 100644 --- a/Project.toml +++ b/Project.toml @@ -6,7 +6,7 @@ version = "0.3.0" ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" -ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" @@ -20,7 +20,6 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" [weakdeps] Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" @@ -36,23 +35,24 @@ AdvancedVIZygoteExt = "Zygote" [compat] ADTypes = "0.1" -Accessors = "0.1.32" -Bijectors = "0.11, 0.12, 0.13" -ChainRules = "1.53.0" +Accessors = "0.1" +Bijectors = "0.12, 0.13" +ChainRulesCore = "1.16" DiffResults = "1" -Distributions = "0.21, 0.22, 0.23, 0.24, 0.25" +Distributions = "0.25.87" DocStringExtensions = "0.8, 0.9" -FillArrays = "1.6.0" -ForwardDiff = "0.10.25" -Functors = "0.4.5" -LogDensityProblems = "2.1.1" +Enzyme = "0.11.7" +FillArrays = "1.3" +ForwardDiff = "0.10.36" +Functors = "0.4" +LogDensityProblems = "2" Optimisers = "0.2.16" -ProgressMeter = "1.0.0" -Requires = "0.5, 1.0" -ReverseDiff = "1.14" +ProgressMeter = "1.6" +Requires = "1.0" +ReverseDiff = 
"1.15.1" SimpleUnPack = "1.1.0" StatsBase = "0.32, 0.33, 0.34" -StatsFuns = "0.8, 0.9, 1" +Zygote = "0.6.63" julia = "1.6" [extras] diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 9bc3d316..7272303a 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -23,7 +23,7 @@ using LogDensityProblems using ADTypes, DiffResults using ADTypes: AbstractADType -using ChainRules: @ignore_derivatives +using ChainRulesCore: @ignore_derivatives using FillArrays using Bijectors From 947a070da945505282711f6a45f6c3723b32b7fd Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 24 Aug 2023 00:32:51 -0400 Subject: [PATCH 136/144] bump compat for ADTypes 0.2 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 143e2098..075ae92f 100644 --- a/Project.toml +++ b/Project.toml @@ -34,7 +34,7 @@ AdvancedVIReverseDiffExt = "ReverseDiff" AdvancedVIZygoteExt = "Zygote" [compat] -ADTypes = "0.1" +ADTypes = "0.1, 0.2" Accessors = "0.1" Bijectors = "0.12, 0.13" ChainRulesCore = "1.16" From a9b3f483f4ae3bd4ac2d569d21697c8a786c448c Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 24 Aug 2023 00:35:32 -0400 Subject: [PATCH 137/144] fix broken LaTeX in README --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 07407fa9..86a57cb6 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ x &\sim \mathrm{LogNormal}\left(\mu_x, \sigma_x^2\right) \\ y &\sim \mathcal{N}\left(\mu_y, \sigma_y^2\right), \end{aligned} $$ + a `LogDensityProblem` can be implemented as ```julia using LogDensityProblems From 54826eb51c0a64bd7fd85b9363300c28e77381d7 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 24 Aug 2023 00:52:35 -0400 Subject: [PATCH 138/144] remove redundant use of PDMats in docs --- README.md | 9 ++++----- docs/Project.toml | 1 - docs/src/advi.md | 5 +---- docs/src/started.md | 6 ++---- 4 files changed, 7 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index 
86a57cb6..695e9ed9 100644 --- a/README.md +++ b/README.md @@ -49,7 +49,7 @@ This corresponds to the automatic differentiation VI (ADVI; Kucukelbir *et al.*, using Bijectors function Bijectors.bijector(model::NormalLogNormal) - @unpack μ_x, σ_x, μ_y, Σ_y = model + (; μ_x, σ_x, μ_y, Σ_y) = model Bijectors.Stacked( Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]), [1:1, 2:1+length(μ_y)]) @@ -60,19 +60,18 @@ A simpler approach is to use `Turing`, where a `Turing.Model` can be automatical Let us instantiate a random normal-log-normal model. ```julia -using PDMats +using LinearAlgebra n_dims = 10 μ_x = randn() σ_x = exp.(randn()) μ_y = randn(n_dims) σ_y = exp.(randn(n_dims)) -model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2)) +model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2)) ``` ADVI can be used as follows: ```julia -using LinearAlgebra using Optimisers using ADTypes, ForwardDiff import AdvancedVI as AVI @@ -81,7 +80,7 @@ b = Bijectors.bijector(model) b⁻¹ = inverse(b) # ADVI objective -objective = AVI.ADVI(model, 10; b=b⁻¹) +objective = AVI.ADVI(model, 10; invbij=b⁻¹) # Mean-field Gaussian variational family d = LogDensityProblems.dimension(model) diff --git a/docs/Project.toml b/docs/Project.toml index 182edd3e..568be1b6 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -7,7 +7,6 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" -PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" diff --git a/docs/src/advi.md b/docs/src/advi.md index 3ac90436..2773dda7 100644 --- a/docs/src/advi.md +++ b/docs/src/advi.md @@ -117,7 +117,6 @@ StickingTheLandingEntropy ```@setup stl using LogDensityProblems using SimpleUnPack -using PDMats using Bijectors using LinearAlgebra using 
Plots @@ -151,15 +150,13 @@ n_dims = 10 σ_x = exp.(randn()) μ_y = randn(n_dims) σ_y = exp.(randn(n_dims)) -model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2)); +model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2)); d = LogDensityProblems.dimension(model); μ = randn(d); L = Diagonal(ones(d)); q0 = AVI.VIMeanFieldGaussian(μ, L) -model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2)); - function Bijectors.bijector(model::NormalLogNormal) @unpack μ_x, σ_x, μ_y, Σ_y = model Bijectors.Stacked( diff --git a/docs/src/started.md b/docs/src/started.md index f3ae54b1..e8392fd7 100644 --- a/docs/src/started.md +++ b/docs/src/started.md @@ -51,14 +51,14 @@ end ``` Let's now instantiate the model ```@example advi -using PDMats +using LinearAlgebra n_dims = 10 μ_x = randn() σ_x = exp.(randn()) μ_y = randn(n_dims) σ_y = exp.(randn(n_dims)) -model = NormalLogNormal(μ_x, σ_x, μ_y, PDMats.PDiagMat(σ_y.^2)); +model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2)); ``` Since the `y` follows a log-normal prior, its support is bounded to be the positive half-space ``\mathbb{R}_+``. @@ -94,8 +94,6 @@ objective = AVI.ADVI(model, n_montecaro; invbij = b⁻¹) ``` For the variational family, we will use the classic mean-field Gaussian family. 
```@example advi -using LinearAlgebra - d = LogDensityProblems.dimension(model); μ = randn(d); L = Diagonal(ones(d)); From 1d1c8ffd320463b6bd9a552227270bb2837344b0 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 24 Aug 2023 01:09:27 -0400 Subject: [PATCH 139/144] fix use `Cholesky` signature supported in 1.6 --- test/models/normallognormal.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/models/normallognormal.jl b/test/models/normallognormal.jl index ec591f2c..e2b9e816 100644 --- a/test/models/normallognormal.jl +++ b/test/models/normallognormal.jl @@ -32,10 +32,10 @@ function normallognormal_fullrank(realtype; rng = default_rng()) μ_x = randn(rng, realtype) σ_x = ℯ μ_y = randn(rng, realtype, n_dims) - L_y = tril(I + ones(realtype, n_dims, n_dims))/2 |> LowerTriangular + L_y = tril(I + ones(realtype, n_dims, n_dims))/2 Σ_y = L_y*L_y' |> Hermitian - model = NormalLogNormal(μ_x, σ_x, μ_y, PDMat(Σ_y, Cholesky(L_y))) + model = NormalLogNormal(μ_x, σ_x, μ_y, PDMat(Σ_y, Cholesky(L_y, 'L', 0))) Σ = Matrix{realtype}(undef, n_dims+1, n_dims+1) Σ[1,1] = σ_x^2 From a0de2cf4e3f9665bd22fe8957d0076770e5411ab Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 24 Aug 2023 01:51:18 -0400 Subject: [PATCH 140/144] fix remove redundant cholesky operation in test --- test/models/normallognormal.jl | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/test/models/normallognormal.jl b/test/models/normallognormal.jl index e2b9e816..b8d72cc0 100644 --- a/test/models/normallognormal.jl +++ b/test/models/normallognormal.jl @@ -37,14 +37,11 @@ function normallognormal_fullrank(realtype; rng = default_rng()) model = NormalLogNormal(μ_x, σ_x, μ_y, PDMat(Σ_y, Cholesky(L_y, 'L', 0))) - Σ = Matrix{realtype}(undef, n_dims+1, n_dims+1) - Σ[1,1] = σ_x^2 - Σ[2:end,2:end] = Σ_y - Σ = Σ |> Hermitian + L = Matrix{realtype}(undef, n_dims+1, n_dims+1) |> LowerTriangular + L[1,1] = σ_x + L[2:end,2:end] = L_y μ = vcat(μ_x, μ_y) - L = cholesky(Σ).L |> 
LowerTriangular - TestModel(model, μ, L, n_dims+1, false) end From f593a67735a5ea60bfe19e4b107594ca57839c47 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Thu, 24 Aug 2023 02:18:29 -0400 Subject: [PATCH 141/144] add `mean`, `var`, `cov` to `LocationScale` --- src/distributions/location_scale.jl | 12 ++++++++++++ test/distributions.jl | 15 +++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/src/distributions/location_scale.jl b/src/distributions/location_scale.jl index 91b6768a..c290b81a 100644 --- a/src/distributions/location_scale.jl +++ b/src/distributions/location_scale.jl @@ -114,6 +114,18 @@ function _rand!(rng::AbstractRNG, q::VILocationScale, x::AbstractMatrix{<:Real}) return x .+= location end +Distributions.mean(q::VILocationScale) = q.location + +function Distributions.var(q::VILocationScale) + C = q.scale + Diagonal(C*C') +end + +function Distributions.cov(q::VILocationScale) + C = q.scale + Hermitian(C*C') +end + """ VIFullRankGaussian(μ::AbstractVector{T}, L::AbstractTriangular{T}; check_args = true) diff --git a/test/distributions.jl b/test/distributions.jl index 175cc96b..9cb158c1 100644 --- a/test/distributions.jl +++ b/test/distributions.jl @@ -47,6 +47,21 @@ using Distributions: _logpdf @test entropy(q) ≈ entropy(q_true) end + @testset "statistics" begin + @testset "mean" begin + @test eltype(mean(q)) == realtype + @test mean(q) == μ + end + @testset "var" begin + @test eltype(var(q)) == realtype + @test var(q) ≈ Diagonal(Σ) + end + @testset "cov" begin + @test eltype(cov(q)) == realtype + @test cov(q) ≈ Σ + end + end + @testset "sampling" begin @testset "rand" begin seed = (0x38bef07cf9cc549d, 0x49e2430080b3f797) From ff32ac642d6aa3a08d371ed895aa6b4026b06b92 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 24 Aug 2023 21:06:56 +0100 Subject: [PATCH 142/144] refactor `optimize` warm-starting interface, add `objargs` argument --- src/optimize.jl | 64 ++++++++++++++++++++++--------------------- test/advi_locscale.jl | 6 ++-- 
test/optimize.jl | 34 ++++++++++++++++++++--- 3 files changed, 66 insertions(+), 38 deletions(-) diff --git a/src/optimize.jl b/src/optimize.jl index 54e7ace0..9a8e6bbd 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -8,7 +8,8 @@ end objective ::AbstractVariationalObjective, restructure, λ₀ ::AbstractVector{<:Real}, - n_max_iter ::Int; + n_max_iter ::Int, + objargs...; kwargs... ) @@ -17,7 +18,8 @@ Optimize the variational objective `objective` by estimating (stochastic) gradie optimize( objective ::AbstractVariationalObjective, q, - n_max_iter::Int; + n_max_iter::Int, + objargs...; kwargs... ) @@ -29,36 +31,34 @@ Optimize the variational objective `objective` by estimating (stochastic) gradie - `restruct`: Function that reconstructs the variational approximation from the flattened parameters. - `q`: Initial variational approximation. The variational parameters must be extractable through `Optimisers.destructure`. - `n_max_iter`: Maximum number of iterations. +- `objargs...`: Arguments to be passed to `objective`. +- `kwargs...`: Additional keywoard arguments. (See below.) # Keyword Arguments - `adbackend`: Automatic differentiation backend. (Type: `<: ADtypes.AbstractADType`.) - `optimizer`: Optimizer used for inference. (Type: `<: Optimisers.AbstractRule`; Default: `Adam`.) - `rng`: Random number generator. (Type: `<: AbstractRNG`; Default: `Random.default_rng()`.) - `show_progress`: Whether to show the progress bar. (Type: `<: Bool`; Default: `true`.) -- `callback!`: Callback function called after every iteration. The signature is `cb(; obj_state, stats, restructure, λ, g)`, which returns a dictionary-like object containing statistics to be displayed on the progress bar. The variational approximation can be reconstructed as `restructure(λ)`. If the estimator associated with `objective` is stateful, `obj_state` contains its state. (Default: `nothing`.) `g` is the stochastic gradient. +- `callback!`: Callback function called after every iteration. 
The signature is `cb(; stats, restructure, λ, g)`, which returns a dictionary-like object containing statistics to be displayed on the progress bar. The variational approximation can be reconstructed as `restructure(λ)`, `g` is the stochastic estimate of the gradient. (Default: `nothing`.) - `prog`: Progress bar configuration. (Default: `ProgressMeter.Progress(n_max_iter; desc="Optimizing", barlen=31, showspeed=true, enabled=prog)`.) - -When resuming from the state of a previous run, use the following keyword arguments: -- `opt_state`: Initial state of the optimizer. -- `obj_state`: Initial state of the objective. +- `state`: Initial value for the internal state of optimization. Used to warm-start from the state of a previous run. (See the returned values below.) (Type: `<: NamedTuple`.) # Returns - `λ`: Variational parameters optimizing the variational objective. -- `stats`: Statistics gathered during inference. -- `opt_state`: Final state of the optimiser. -- `obj_state`: Final state of the objective. +- `logstats`: Statistics and logs gathered during optimization. +- `states`: Collection of the final internal states of optimization. This can used later to warm-start from the last iteration of the corresponding run. """ function optimize( objective ::AbstractVariationalObjective, restructure, λ₀ ::AbstractVector{<:Real}, - n_max_iter ::Int; - adbackend::AbstractADType, + n_max_iter ::Int, + objargs...; + adbackend ::AbstractADType, optimizer ::Optimisers.AbstractRule = Optimisers.Adam(), rng ::AbstractRNG = default_rng(), show_progress::Bool = true, - opt_state = nothing, - obj_state = nothing, + state ::NamedTuple = NamedTuple(), callback! = nothing, prog = ProgressMeter.Progress( n_max_iter; @@ -66,37 +66,39 @@ function optimize( barlen = 31, showspeed = true, enabled = show_progress - ) + ) ) - λ = copy(λ₀) - opt_state = isnothing(opt_state) ? Optimisers.setup(optimizer, λ) : opt_state - obj_state = isnothing(obj_state) ? 
init(rng, objective, λ, restructure) : obj_state - grad_buf = DiffResults.GradientResult(λ) - stats = NamedTuple[] + λ = copy(λ₀) + opt_st = haskey(state, :opt) ? state.opt : Optimisers.setup(optimizer, λ) + obj_st = haskey(state, :obj) ? state.obj : init(rng, objective, λ, restructure) + grad_buf = DiffResults.GradientResult(λ) + logstats = NamedTuple[] for t = 1:n_max_iter stat = (iteration=t,) - grad_buf, obj_state, stat′ = estimate_gradient( - rng, adbackend, objective, obj_state, λ, restructure, grad_buf) + grad_buf, obj_st, stat′ = estimate_gradient( + rng, adbackend, objective, obj_st, + λ, restructure, grad_buf; objargs... + ) stat = merge(stat, stat′) - g = DiffResults.gradient(grad_buf) - opt_state, λ = Optimisers.update!(opt_state, λ, g) - stat′ = (iteration = t,) - stat = merge(stat, stat′) + g = DiffResults.gradient(grad_buf) + opt_st, λ = Optimisers.update!(opt_st, λ, g) if !isnothing(callback!) - stat′ = callback!(; obj_state, stat, restructure, λ, g) + stat′ = callback!(; stat, restructure, λ, g) stat = !isnothing(stat′) ? merge(stat′, stat) : stat end @debug "Iteration $t" stat... pm_next!(prog, stat) - push!(stats, stat) + push!(logstats, stat) end - λ, map(identity, stats), opt_state, obj_state + state = (opt=opt_st, obj=obj_st) + logstats = map(identity, logstats) + λ, logstats, state end function optimize(objective ::AbstractVariationalObjective, @@ -104,8 +106,8 @@ function optimize(objective ::AbstractVariationalObjective, n_max_iter::Int; kwargs...) λ, restructure = Optimisers.destructure(q₀) - λ, stats, opt_state, obj_state = optimize( + λ, logstats, state = optimize( objective, restructure, λ, n_max_iter; kwargs... 
) - restructure(λ), stats, opt_state, obj_state + restructure(λ), logstats, state end diff --git a/test/advi_locscale.jl b/test/advi_locscale.jl index d5250ce8..8d8df1e9 100644 --- a/test/advi_locscale.jl +++ b/test/advi_locscale.jl @@ -50,7 +50,7 @@ using ReTest @testset "convergence" begin Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true) - q, stats, _, _ = optimize( + q, stats, _ = optimize( obj, q₀, T; optimizer = Optimisers.Adam(realtype(η)), show_progress = PROGRESS, @@ -69,7 +69,7 @@ using ReTest @testset "determinism" begin rng = Philox4x(UInt64, seed, 8) - q, stats, _, _ = optimize( + q, stats, _ = optimize( obj, q₀, T; optimizer = Optimisers.Adam(realtype(η)), show_progress = PROGRESS, @@ -80,7 +80,7 @@ using ReTest L = q.scale rng_repl = Philox4x(UInt64, seed, 8) - q, stats, _, _ = optimize( + q, stats, _ = optimize( obj, q₀, T; optimizer = Optimisers.Adam(realtype(η)), show_progress = PROGRESS, diff --git a/test/optimize.jl b/test/optimize.jl index 2369432c..a8013be2 100644 --- a/test/optimize.jl +++ b/test/optimize.jl @@ -21,7 +21,7 @@ using ReTest optimizer = Optimisers.Adam(1e-2) rng = Philox4x(UInt64, seed, 8) - q_ref, stats_ref, _, _ = optimize( + q_ref, stats_ref, _ = optimize( obj, q₀, T; optimizer, show_progress = false, @@ -34,7 +34,7 @@ using ReTest λ₀, re = Optimisers.destructure(q₀) rng = Philox4x(UInt64, seed, 8) - λ, stats, _, _ = optimize( + λ, stats, _ = optimize( obj, re, λ₀, T; optimizer, show_progress = false, @@ -49,13 +49,14 @@ using ReTest rng = Philox4x(UInt64, seed, 8) test_values = rand(rng, T) - callback!(; stat, obj_state, restructure, λ, g) = begin + callback!(; stat, restructure, λ, g) = begin (test_value = test_values[stat.iteration],) end rng = Philox4x(UInt64, seed, 8) - _, stats, _, _ = optimize( + _, stats, _ = optimize( obj, q₀, T; + optimizer, show_progress = false, rng, adbackend, @@ -63,4 +64,29 @@ using ReTest ) @test [stat.test_value for stat ∈ stats] == test_values end + + @testset "warm start" begin + rng = 
Philox4x(UInt64, seed, 8) + + T_first = div(T,2) + T_last = T - T_first + + q_first, _, state = optimize( + obj, q₀, T_first; + optimizer, + show_progress = false, + rng, + adbackend + ) + + q, stats, _ = optimize( + obj, q_first, T_last; + optimizer, + show_progress = false, + state, + rng, + adbackend + ) + @test q == q_ref + end end From bc5cfd348afe928bcbc7e714cd46f862e5b796a4 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 24 Aug 2023 21:07:22 +0100 Subject: [PATCH 143/144] update documentation for `optimize` --- docs/src/advi.md | 4 ++-- docs/src/started.md | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/src/advi.md b/docs/src/advi.md index 2773dda7..f4fe3715 100644 --- a/docs/src/advi.md +++ b/docs/src/advi.md @@ -186,7 +186,7 @@ stl = AVI.ADVI(model, n_montecarlo; entropy = AVI.StickingTheLandingEntropy(), i ```@setup stl n_max_iter = 10^4 -_, stats_cfe, _, _ = AVI.optimize( +_, stats_cfe, _ = AVI.optimize( cfe, q0, n_max_iter; @@ -195,7 +195,7 @@ _, stats_cfe, _, _ = AVI.optimize( optimizer = Optimisers.Adam(1e-3) ); -_, stats_stl, _, _ = AVI.optimize( +_, stats_stl, _ = AVI.optimize( stl, q0, n_max_iter; diff --git a/docs/src/started.md b/docs/src/started.md index e8392fd7..e3e78c35 100644 --- a/docs/src/started.md +++ b/docs/src/started.md @@ -102,7 +102,7 @@ q = AVI.VIMeanFieldGaussian(μ, L) Passing `objective` and the initial variational approximation `q` to `optimize` performs inference. 
```@example advi n_max_iter = 10^4 -q, stats, _, _ = AVI.optimize( +q, stats, _ = AVI.optimize( objective, q, n_max_iter; From de4284eeb7c80b29acf1bae77746e523cb6b2602 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Fri, 25 Aug 2023 19:08:10 +0100 Subject: [PATCH 144/144] fix CUDA-compatibility bugs --- src/objectives/elbo/advi.jl | 4 ++-- src/optimize.jl | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/objectives/elbo/advi.jl b/src/objectives/elbo/advi.jl index f9a61d81..354c822c 100644 --- a/src/objectives/elbo/advi.jl +++ b/src/objectives/elbo/advi.jl @@ -88,8 +88,8 @@ function estimate_gradient( rng ::AbstractRNG, adbackend ::AbstractADType, advi ::ADVI, - est_state, - λ ::Vector{<:Real}, + obj_state, + λ ::AbstractVector{<:Real}, restructure, out ::DiffResults.MutableDiffResult ) diff --git a/src/optimize.jl b/src/optimize.jl index 9a8e6bbd..ea2fd5a1 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -71,7 +71,7 @@ function optimize( λ = copy(λ₀) opt_st = haskey(state, :opt) ? state.opt : Optimisers.setup(optimizer, λ) obj_st = haskey(state, :obj) ? state.obj : init(rng, objective, λ, restructure) - grad_buf = DiffResults.GradientResult(λ) + grad_buf = DiffResults.DiffResult(zero(eltype(λ)), similar(λ)) logstats = NamedTuple[] for t = 1:n_max_iter