diff --git a/Project.toml b/Project.toml index cccad6e..170b493 100644 --- a/Project.toml +++ b/Project.toml @@ -7,8 +7,10 @@ version = "1.0.0" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" @@ -18,14 +20,16 @@ TinyHugeNumbers = "783c9a47-75a3-44ac-a16b-f1ab7b3acf04" [compat] Distributions = "0.25" DomainSets = "0.7" +LinearAlgebra = "1.9" Random = "1.9" SpecialFunctions = "2.3" Statistics = "1.9" StatsAPI = "1.7" StatsBase = "0.34" +StaticArrays = "1.6" +LoopVectorization = "0.12" StatsFuns = "1.3" TinyHugeNumbers = "1.0" -LinearAlgebra = "1.9" julia = "1.9" [extras] diff --git a/src/BayesBase.jl b/src/BayesBase.jl index 0465df9..27a0a90 100644 --- a/src/BayesBase.jl +++ b/src/BayesBase.jl @@ -87,6 +87,7 @@ include("prod.jl") include("densities/pointmass.jl") include("densities/function.jl") +include("densities/samplelist.jl") include("densities/mixture.jl") include("densities/factorizedjoint.jl") diff --git a/src/densities/function.jl b/src/densities/function.jl index a1f99a6..1794be5 100644 --- a/src/densities/function.jl +++ b/src/densities/function.jl @@ -11,7 +11,6 @@ getdomain(dist::AbstractContinuousGenericLogPdf) = dist.domain getlogpdf(dist::AbstractContinuousGenericLogPdf) = dist.logpdf BayesBase.value_support(::Type{<:AbstractContinuousGenericLogPdf}) = Continuous -BayesBase.value_support(::AbstractContinuousGenericLogPdf) = Continuous # We throw an error on purpose, since we do not want to use `AbstractContinuousGenericLogPdf` much without approximations # We want to 
encourage a user to use approximate generic log-pdfs as much as possible instead @@ -92,7 +91,6 @@ function ContinuousUnivariateLogPdf(f::Function) end BayesBase.variate_form(::Type{<:ContinuousUnivariateLogPdf}) = Univariate -BayesBase.variate_form(::ContinuousUnivariateLogPdf) = Univariate function BayesBase.promote_variate_type( ::Type{Univariate}, ::Type{AbstractContinuousGenericLogPdf} @@ -172,7 +170,6 @@ struct ContinuousMultivariateLogPdf{D<:DomainSets.Domain,F} <: end BayesBase.variate_form(::Type{<:ContinuousMultivariateLogPdf}) = Multivariate -BayesBase.variate_form(::ContinuousMultivariateLogPdf) = Multivariate function BayesBase.promote_variate_type( ::Type{Multivariate}, ::Type{AbstractContinuousGenericLogPdf} diff --git a/src/densities/pointmass.jl b/src/densities/pointmass.jl index 4032db3..1abab89 100644 --- a/src/densities/pointmass.jl +++ b/src/densities/pointmass.jl @@ -10,10 +10,10 @@ end getpointmass(distribution::PointMass) = distribution.point getpointmass(point::Union{Real,AbstractArray}) = point -BayesBase.variate_form(::PointMass{T}) where {T<:Real} = Univariate -BayesBase.variate_form(::PointMass{V}) where {T,V<:AbstractVector{T}} = Multivariate -BayesBase.variate_form(::PointMass{M}) where {T,M<:AbstractMatrix{T}} = Matrixvariate -BayesBase.variate_form(::PointMass{U}) where {T,U<:UniformScaling{T}} = Matrixvariate +BayesBase.variate_form(::Type{PointMass{T}}) where {T<:Real} = Univariate +BayesBase.variate_form(::Type{PointMass{V}}) where {T,V<:AbstractVector{T}} = Multivariate +BayesBase.variate_form(::Type{PointMass{M}}) where {T,M<:AbstractMatrix{T}} = Matrixvariate +BayesBase.variate_form(::Type{PointMass{U}}) where {T,U<:UniformScaling{T}} = Matrixvariate function BayesBase.mean(fn::F, distribution::PointMass) where {F<:Function} return fn(mean(distribution)) diff --git a/src/densities/samplelist.jl b/src/densities/samplelist.jl new file mode 100644 index 0000000..39a5d43 --- /dev/null +++ b/src/densities/samplelist.jl @@ -0,0 
+1,821 @@ + +import StatsBase: Weights + +using StaticArrays +using LoopVectorization + +export SampleList, SampleListMeta + +abstract type AbstractSampleListSamplingMethod end + +struct BootstrapImportanceSampling <: AbstractSampleListSamplingMethod end + +mutable struct SampleListCache{M, C} + mean :: M + cov :: C + is_mean_cached :: Bool + is_cov_cached :: Bool +end + +SampleListCache(::Type{T}, dims::Tuple{}) where {T} = SampleListCache(zero(T), zero(T), false, false) +SampleListCache(::Type{T}, dims::Tuple{Int}) where {T} = SampleListCache(zeros(T, first(dims)), zeros(T, first(dims), first(dims)), false, false) +SampleListCache(::Type{T}, dims::Tuple{Int, Int}) where {T} = SampleListCache(zeros(T, dims), zeros(T, prod(dims), prod(dims)), false, false) + +is_mean_cached(cache::SampleListCache) = cache.is_mean_cached +is_cov_cached(cache::SampleListCache) = cache.is_cov_cached + +get_mean_storage(cache::SampleListCache) = cache.mean +get_cov_storage(cache::SampleListCache) = cache.cov + +cache_mean!(cache::SampleListCache, mean) = begin + cache.mean = mean + cache.is_mean_cached = true + mean +end +cache_cov!(cache::SampleListCache, cov) = begin + cache.cov = cov + cache.is_cov_cached = true + cov +end + +struct SampleListMeta{W, E, LP, LI} + unnormalisedweights :: W + entropy :: E + logproposal :: LP + logintegrand :: LI +end + +get_unnormalised_weights(meta::SampleListMeta) = meta.unnormalisedweights +get_entropy(meta::SampleListMeta) = meta.entropy +get_logproposal(meta::SampleListMeta) = meta.logproposal +get_logintegrand(meta::SampleListMeta) = meta.logintegrand + +call_logproposal(logproposal::Function, x) = logproposal(x) +call_logproposal(logproposal::Any, x) = logpdf(logproposal, x) + +call_logintegrand(logintegrand::Function, x) = logintegrand(x) +call_logintegrand(logintegrand::Any, x) = logpdf(logintegrand, x) + +""" + SampleList + +Generic distribution represented as a list of weighted samples. 
+ +# Arguments +- `samples::S` +- `weights::W`: optional, equivalent to `fill(1 / N, N)` by default, where `N` is the length of `samples` container +""" +struct SampleList{D, S, W, C, M} + samples :: S + weights :: W + cache :: C + meta :: M + + function SampleList(::Val{D}, samples::S, weights::W, meta::M = nothing) where {D, S, W, M} + @assert div(length(samples), prod(D)) === length(weights) "Invalid sample list samples and weights lengths. `samples` has length $(length(samples)), `weights` has length $(length(weights))" + @assert eltype(samples) <: Number "Invalid eltype of samples container. Should be a subtype of `Number`, but $(eltype(samples)) has been found. Samples should be stored in a linear one dimensional vector even for multivariate and matrixvariate cases." + @assert eltype(weights) <: Number "Invalid eltype of weights container. Should be a subtype of `Number`, but $(eltype(weights)) has been found." + cache = SampleListCache(promote_type(eltype(samples), eltype(weights)), D) + return new{D, S, W, typeof(cache), M}(samples, weights, cache, meta) + end +end + +Base.show(io::IO, sl::SampleList) = sample_list_show(io, variate_form(typeof(sl)), sl) +Base.similar(sl::SampleList{D}) where {D} = SampleList(Val(D), similar(sl.samples), similar(sl.weights)) + +sample_list_show(io::IO, ::Type{Univariate}, sl::SampleList) = print(io, "SampleList(Univariate, ", length(sl), ")") +sample_list_show(io::IO, ::Type{Multivariate}, sl::SampleList) = print(io, "SampleList(Multivariate(", ndims(sl), "), ", length(sl), ")") +sample_list_show(io::IO, ::Type{Matrixvariate}, sl::SampleList) = print(io, "SampleList(Matrixvariate", ndims(sl), ", ", length(sl), ")") + +function SampleList(samples::S) where {S <: AbstractVector} + N = length(samples) + return SampleList(samples, fill(one(deep_eltype(S)) / N, N)) +end + +function SampleList(samples::S, weights::W, meta::M = nothing) where {S, W, M} + nsamples = length(samples) + @assert nsamples !== 0 "Empty samples list" + 
@assert sum(weights) ≈ one(eltype(weights)) "Weights must sum up to one. sum(weights) = $(sum(weights))" + D = size(first(samples)) + return SampleList(Val(D), sample_list_linearize(samples, nsamples, prod(D)), weights, meta) +end + +const DEFAULT_SAMPLE_LIST_N_SAMPLES = 5000 + +## Utility functions + +BayesBase.paramfloattype(sl::SampleList) = promote_type(eltype(sl.samples), eltype(sl.weights)) + +BayesBase.convert_paramfloattype(::Type{T}, sl::SampleList{D}) where {T, D} = SampleList(Val(D), convert_paramfloattype(T, sl.samples), convert_paramfloattype(T, sl.weights), sl.meta) + +Base.eltype(::Type{<:SampleList{D, S, W}}) where {D, S, W} = Tuple{sample_list_eltype(SampleList, D, S), eltype(W)} + +BayesBase.sampletype(::SampleList{D, S}) where {D, S} = sample_list_eltype(SampleList, D, S) + +sample_list_eltype(::Type{SampleList}, ndims::Tuple{}, ::Type{S}) where {S} = eltype(S) +sample_list_eltype(::Type{SampleList}, ndims::Tuple{Int}, ::Type{S}) where {S} = SVector{ndims[1], eltype(S)} +sample_list_eltype(::Type{SampleList}, ndims::Tuple{Int, Int}, ::Type{S}) where {S} = SMatrix{ndims[1], ndims[2], eltype(S), ndims[1] * ndims[2]} + +BayesBase.deep_eltype(::Type{<:SampleList{D, S}}) where {D, S} = eltype(S) + +## Variate forms + +BayesBase.variate_form(::Type{<:SampleList{D}}) where {D} = sample_list_variate_form(D) + +sample_list_variate_form(::Tuple{}) = Univariate +sample_list_variate_form(::Tuple{Int}) = Multivariate +sample_list_variate_form(::Tuple{Int, Int}) = Matrixvariate + +## Getters + +get_weights(sl::SampleList) = get_linear_weights(sl) +get_samples(sl::SampleList) = SamplesOnlyIterator(sl) + +get_linear_weights(sl::SampleList) = sl.weights +get_linear_samples(sl::SampleList) = sl.samples +get_cache(sl::SampleList) = sl.cache +get_meta(sl::SampleList) = sample_list_check_meta(sl.meta) +is_meta_present(sl::SampleList) = sl.meta !== nothing + +get_data(sl::SampleList) = (length(sl), get_linear_samples(sl), get_linear_weights(sl)) + 
+sample_list_check_meta(meta::Any) = meta +sample_list_check_meta(meta::Nothing) = error("SampleList object has not associated meta information with it.") + +get_unnormalised_weights(sl::SampleList) = get_unnormalised_weights(get_meta(sl)) +get_entropy(sl::SampleList) = get_entropy(get_meta(sl)) +get_logproposal(sl::SampleList) = get_logproposal(get_meta(sl)) +get_logintegrand(sl::SampleList) = get_logintegrand(get_meta(sl)) + +call_logproposal(sl::SampleList, x) = call_logproposal(get_logproposal(sl), x) +call_logintegrand(sl::SampleList, x) = call_logintegrand(get_logintegrand(sl), x) + +Base.length(sl::SampleList) = div(length(get_linear_samples(sl)), prod(ndims(sl))) +Base.ndims(sl::SampleList) = sample_list_ndims(variate_form(typeof(sl)), sl) +Base.size(sl::SampleList) = (length(sl),) + +sample_list_ndims(::Type{Univariate}, sl::SampleList{D}) where {D} = 1 +sample_list_ndims(::Type{Multivariate}, sl::SampleList{D}) where {D} = first(D) +sample_list_ndims(::Type{Matrixvariate}, sl::SampleList{D}) where {D} = D + +## Statistics + +# Returns a zeroed container for mean +function sample_list_zero_element(sl::SampleList) + T = promote_type(eltype(get_linear_weights(sl)), eltype(get_linear_samples(sl))) + return sample_list_zero_element(variate_form(typeof(sl)), T, sl) +end + +sample_list_zero_element(::Type{Univariate}, ::Type{T}, sl::SampleList) where {T} = zero(T) +sample_list_zero_element(::Type{Multivariate}, ::Type{T}, sl::SampleList) where {T} = zeros(T, ndims(sl)) +sample_list_zero_element(::Type{Matrixvariate}, ::Type{T}, sl::SampleList) where {T} = zeros(T, ndims(sl)) + +# Generic mean_cov + +BayesBase.mean_cov(sl::SampleList) = sample_list_mean_cov(sl, Val(true)) +BayesBase.mean_var(sl::SampleList) = sample_list_mean_var(variate_form(typeof(sl)), sl) + +## + +BayesBase.mean(sl::SampleList) = sample_list_mean(sl, Val(true)) +BayesBase.var(sl::SampleList) = last(mean_var(sl)) +BayesBase.cov(sl::SampleList) = last(mean_cov(sl)) 
+BayesBase.invcov(sl::SampleList) = inv(cov(sl)) +BayesBase.std(sl::SampleList) = sqrt(cov(sl)) +BayesBase.logdetcov(sl::SampleList) = logdet(cov(sl)) + +Base.precision(sl::SampleList) = invcov(sl) + +function BayesBase.mean_precision(sl::SampleList) + μ, Σ = mean_cov(sl) + return μ, inv(Σ) +end + +function BayesBase.weightedmean_precision(sl::SampleList) + μ, Λ = mean_precision(sl) + return Λ * μ, Λ +end + +BayesBase.weightedmean(sl::SampleList) = first(weightedmean_precision(sl)) + +BayesBase.mean(::typeof(log), sl::SampleList) = sample_list_logmean(variate_form(typeof(sl)), sl) +BayesBase.mean(::typeof(xtlog), sl::SampleList) = sample_list_meanlogmean(variate_form(typeof(sl)), sl) +BayesBase.mean(::typeof(mirrorlog), sl::SampleList) = sample_list_mirroredlogmean(variate_form(typeof(sl)), sl) + +# Generic version of the mean function for arbitrary `f` +function BayesBase.mean(f::F, sl::SampleList) where {F} + return mapreduce(+, zip(get_weights(sl), get_samples(sl))) do (weight, sample) + return weight * f(sample) + end +end + +## + +# Differential entropy for SampleList +# Entropy is pre-computed during computation of the marginal in `approximate_prod_with_sample_list` function + +BayesBase.entropy(sl::SampleList) = get_entropy(get_meta(sl)) + +# `entropy` for the `SampleList` is not defined if `meta` is of type `Nothing` +function Distributions.entropy(::SampleList{D, S, W, C, Nothing}) where {D, S, W, C} + error("`entropy` for the `SampleList` is not defined if `meta` is of type `Nothing`") +end + +## + +BayesBase.vague(::Type{SampleList}; nsamples::Int = DEFAULT_SAMPLE_LIST_N_SAMPLES) = sample_list_vague(Univariate, nsamples) +BayesBase.vague(::Type{SampleList}, dims::Int; nsamples::Int = DEFAULT_SAMPLE_LIST_N_SAMPLES) = sample_list_vague(Multivariate, dims, nsamples) +BayesBase.vague(::Type{SampleList}, dims::Tuple{Int, Int}; nsamples::Int = DEFAULT_SAMPLE_LIST_N_SAMPLES) = sample_list_vague(Matrixvariate, dims, nsamples) +BayesBase.vague(::Type{SampleList}, 
dim1::Int, dim2::Int; nsamples::Int = DEFAULT_SAMPLE_LIST_N_SAMPLES) = sample_list_vague(Matrixvariate, (dim1, dim2), nsamples) + +## + +BayesBase.rand(samplelist::SampleList) = rand(Random.default_rng(), samplelist) +BayesBase.rand(samplelist::SampleList, n::Integer) = rand(Random.default_rng(), samplelist, n) + +function BayesBase.rand(rng::AbstractRNG, samplelist::SampleList) + return rand(rng, get_samples(samplelist)) +end + +function BayesBase.rand(rng::AbstractRNG, samplelist::SampleList, n::Integer) + return rand(rng, get_samples(samplelist), n) +end + +## + +sample_list_default_prod_strategy() = BootstrapImportanceSampling() + +## prod related stuff +function approximate_prod_with_sample_list(x, y, nsamples::Int = DEFAULT_SAMPLE_LIST_N_SAMPLES) + return approximate_prod_with_sample_list(Random.GLOBAL_RNG, x, y, nsamples) +end + +function approximate_prod_with_sample_list(rng::AbstractRNG, x, y, nsamples::Int = DEFAULT_SAMPLE_LIST_N_SAMPLES) + return approximate_prod_with_sample_list(rng, sample_list_default_prod_strategy(), x, y, nsamples) +end + +# `x` is proposal distribution +# `y` is integrand distribution +function approximate_prod_with_sample_list(rng::AbstractRNG, ::BootstrapImportanceSampling, x::Any, y::Any, nsamples::Int = DEFAULT_SAMPLE_LIST_N_SAMPLES) + @assert nsamples >= 1 "Number of samples should be greater than 1" + + xlogpdf, xsample = logpdf_sampling_optimized(x) + ylogpdf, ysample = logpdf_sampling_optimized(y) + + T = promote_samplefloattype(x, y) + U = variate_form(typeof(x)) + xsize = size(x) + preallocated = preallocate_samples(T, xsize, nsamples) + samples = rand!(rng, xsample, reshape(preallocated, (xsize..., nsamples))) + + raw_weights = Vector{T}(undef, nsamples) # un-normalised + norm_weights = Vector{T}(undef, nsamples) # normalised + + H_x = zero(T) + weights_sum = zero(T) + + for i in 1:nsamples + # Static indexing from reshaped array + sample_i = static_getindex(U, xsize, samples, i) + # Apply log-pdf functions to the 
samples + log_sample_x = logpdf(xlogpdf, sample_i) + log_sample_y = logpdf(ylogpdf, sample_i) + + raw_weight = exp(log_sample_y) + + raw_weights[i] = raw_weight + norm_weights[i] = raw_weight # will be renormalised later + + weights_sum += raw_weight + H_x += raw_weight * (log_sample_x + log_sample_y) + end + + # Normalise weights + @turbo for i in 1:nsamples + norm_weights[i] /= weights_sum + end + + # Renormalise H_x + H_x /= weights_sum + + # Compute the separate contributions to the entropy + H_y = log(weights_sum) - log(nsamples) + H_x = -H_x + + entropy = H_x + H_y + + # Inform next step about the proposal and integrand to be used in entropy calculation in smoothing + logproposal = xlogpdf + logintegrand = ylogpdf + + meta = SampleListMeta(raw_weights, entropy, logproposal, logintegrand) + + return SampleList(Val(xsize), preallocated, norm_weights, meta) +end + +function approximate_prod_with_sample_list(rng::AbstractRNG, ::AbstractSampleListSamplingMethod, x::SampleList, y::SampleList, nsamples::Int = DEFAULT_SAMPLE_LIST_N_SAMPLES) + error("Unsupported SampleList × SampleList prod operation.") +end + +function approximate_prod_with_sample_list(rng::AbstractRNG, method::AbstractSampleListSamplingMethod, x::SampleList, y::Any, nsamples::Int = DEFAULT_SAMPLE_LIST_N_SAMPLES) + return approximate_prod_with_sample_list(rng, method, y, x, nsamples) +end + +# prod of a pdf (or distribution) message and a SampleList message +# this function is capable to calculate entropy with SampleList messages in VMP setting +function approximate_prod_with_sample_list(rng::AbstractRNG, ::BootstrapImportanceSampling, x::Any, y::SampleList{D}, nsamples::Int = DEFAULT_SAMPLE_LIST_N_SAMPLES) where {D} + + # TODO: In principle it is possible to implement different prod approximation for different nsamples + # TODO: This feature would be probably super rare in use so lets postpone it and mark as todo + @assert length(y) === nsamples "Unsupported SampleList prod approximation: nsamples 
should match" + + # TODO: Is it possible to support different variate forms here? + @assert variate_form(typeof(x)) === variate_form(typeof(y)) "Unsupported SampleList prod approximation: variate forms should match" + + # Suppose in the previous time step m1(pdf) and m2(pdf) messages collided. + # The resulting collision m3 (sampleList) = m1*m2 is supposed to carry + # the proposal (m1) and integrand (m2) distributions. m1 is the message from which + # the samples are drawn. m2 is the message on which the samples are evaluated and + # weights are calculated. In case Particle Filtering (BP), entropy will not be calculated + # and in the first step there won't be any integrand information. + + xlogpdf = logpdf_optimized(x) + + log_integrand = if is_meta_present(y) && get_logintegrand(y) !== nothing + # recall that we are calculating m3*m4. If m3 consists of integrand information + # update it: new_integrand = m2*m3. This allows us to collide arbitrary number of beliefs + # to approximate posterior and yet estimate the entropy. + let y_logintegrand = get_logintegrand(y) + (sample) -> call_logintegrand(y_logintegrand, sample) + logpdf(xlogpdf, sample) + end + else + # If there is no integrand information before, set it to m4 + xlogpdf + end + + samples = get_samples(y) # samples come from proposal (m1) + weights = get_weights(y) + + # Resulting samples and weights will go here + rcontainer = similar(y) + # vec here just to convert unusual containers into an array just in case + # does nothing if `get_weights` returns an array + rsamples, rweights = get_samples(rcontainer), vec(get_weights(rcontainer)) + + rweights_raw = similar(rweights) + rweights_prod_sum = zero(eltype(get_weights(y))) + + H_x = zero(eltype(rweights_raw)) + + # Compute sample weights + @inbounds for i in 1:nsamples + log_sample_x = logpdf(xlogpdf, samples[i]) # evaluate samples in logm4, i.e. 
logm4(s) + raw_weight = exp(log_sample_x) # m4(s) + raw_weight_prod = raw_weight * weights[i] # update the weights of posterior w_unnormalized = m4(s)*w_prev + + rweights_raw[i] = raw_weight + rweights[i] = raw_weight_prod + + rweights_prod_sum += raw_weight_prod + + H_x += raw_weight_prod * log_sample_x + end + + H_x /= rweights_prod_sum + + # Normalize prod weights + @turbo for i in 1:nsamples + rweights[i] /= rweights_prod_sum + end + + # Effective number of particles (Base.Generator is allocation free version of map) + # neff = sum(rweights .^ 2) + neff = 1 / mapreduce(abs2, +, rweights) + + # Resample and readjust entropy approximation if required + if neff < nsamples / 10 + sample!(rng, samples, Weights(weights), rsamples) + fillv = one(eltype(rweights)) / nsamples + fill!(rweights, fillv) + # Readjust H_x after resampling + H_x = zero(eltype(rweights_raw)) + for i in 1:nsamples + log_sample_x = logpdf(xlogpdf, rsamples[i]) + H_x += (exp(log_sample_x) * fillv) * log_sample_x + end + # H_x /= 1 rweights are normalized at this point, no need to track of its sum + else + # Just copy the existing samples instead + copyto!(get_linear_samples(rsamples), get_linear_samples(samples)) + end + + meta = if is_meta_present(y) && get_logproposal(y) !== nothing && get_unnormalised_weights(y) !== nothing + y_unnormalised_weights = get_unnormalised_weights(y) + r_unnormalised_weights = similar(y_unnormalised_weights) + r_unnormalised_weights_sum = zero(eltype(r_unnormalised_weights)) + @inbounds for i in 1:nsamples + r_unnormalised_weights_prod = y_unnormalised_weights[i] * rweights_raw[i] + r_unnormalised_weights[i] = r_unnormalised_weights_prod + r_unnormalised_weights_sum += r_unnormalised_weights_prod + + H_x += rweights[i] * (call_logproposal(y, rsamples[i]) + log(y_unnormalised_weights[i])) + end + + H_y = log(r_unnormalised_weights_sum) - log(nsamples) + H_x = -H_x + + entropy = H_x + H_y + + SampleListMeta(r_unnormalised_weights, entropy, get_logproposal(y), 
log_integrand) + else + SampleListMeta(nothing, nothing, nothing, log_integrand) + end + + return SampleList(Val(D), get_linear_samples(rsamples), rweights, meta) +end + +############################################################################################ +## Everything below this comment is a low-level implementation of SampleList +## It consists of various routines to compute statistics from internal list implementation +## Here we also implement iteration utilities +## ########################################################################################## + +## Lowlevel implementation below... + +@inline static_getindex(::Type{Univariate}, ndims::Tuple{}, samples, i) = samples[i] +@inline static_getindex(::Type{Multivariate}, ndims::Tuple{Int}, samples, i) = view(samples, :, i) +@inline static_getindex(::Type{Matrixvariate}, ndims::Tuple{Int, Int}, samples, i) = view(samples, :, :, i) + +## Preallocation utilities + +preallocate_samples(::Type{T}, dims::Tuple, length::Int) where {T} = Vector{T}(undef, length * prod(dims)) + +## Linearization functions + +# Here we cast an array of arrays into a single flat array of floats for better performance +# There is a package for this called ArraysOfArrays.jl, but the performance of handwritten version is way better +# We provide custom optimized mean/cov function for our implementation with LoopVectorization.jl package +function sample_list_linearize end + +function sample_list_linearize(samples::AbstractVector{T}, nsamples, size) where {T <: Number} + return samples +end + +function sample_list_linearize(samples::AbstractVector, nsamples, size) + T = deep_eltype(samples) + alloc = Vector{T}(undef, nsamples * size) + for i in 1:nsamples + copyto!(alloc, (i - 1) * size + 1, samples[i], 1, size) + end + return alloc +end + +## Cache utilities + +sample_list_mean(sl::SampleList, cached) = sample_list_mean(variate_form(typeof(sl)), sl, cached) +sample_list_mean_cov(sl::SampleList, cached) = 
sample_list_mean_cov(variate_form(typeof(sl)), sl, cached) + +# Cache ignoring versions +function sample_list_mean(::Type{U}, sl::SampleList, ::Val{false}) where {U} + return sample_list_mean!(fill!(similar(get_mean_storage(get_cache(sl))), zero(deep_eltype(sl))), U, sl) +end + +function sample_list_mean_cov(::Type{U}, sl::SampleList, ::Val{false}) where {U} + mean = sample_list_mean(U, sl, Val(false)) + cov = fill!(similar(get_cov_storage(get_cache(sl))), zero(deep_eltype(sl))) + sample_list_covm!(cov, mean, U, sl) + return mean, cov +end + +# By default we try to save mean in an internal cache +function sample_list_mean(::Type{U}, sl::SampleList, ::Val{true}) where {U} + cache = get_cache(sl) + mean = get_mean_storage(cache) + # If no cache present, compute and save + if !is_mean_cached(cache) + mean = cache_mean!(cache, sample_list_mean!(mean, U, sl)) + end + return mean +end + +# By default we try to save cov in an internal cache +function sample_list_mean_cov(::Type{U}, sl::SampleList, ::Val{true}) where {U} + cache = get_cache(sl) + mean = sample_list_mean(U, sl, Val(true)) + cov = get_cov_storage(cache) + # If no cache present, compute and save + if !is_cov_cached(cache) + cov = cache_cov!(cache, sample_list_covm!(cov, mean, U, sl)) + end + return (mean, cov) +end + +## Specific implementations + +# Compute mean in a preallocated container and return it +function sample_list_mean! end +# Compute covariance with known mean in a preallocated container and return it +function sample_list_covm! 
end +# Compute mean and variance +function sample_list_mean_var end +# Compute E[log(x)] +function sample_list_logmean end +# Compute E[xlog(x)] +function sample_list_meanlogmean end +# Compute E[log(1 - x)] +function sample_list_mirroredlogmean end +# Return vague weak-informative sample list +function sample_list_vague end + +## Univariate + +function sample_list_mean!(μ, ::Type{Univariate}, sl::SampleList) + n, samples, weights = get_data(sl) + @turbo for i in 1:n + μ += weights[i] * samples[i] + end + return μ +end + +function sample_list_covm!(σ², μ, ::Type{Univariate}, sl::SampleList) + n, samples, weights = get_data(sl) + @turbo for i in 1:n + σ² += weights[i] * abs2(samples[i] - μ) + end + σ² = (n / (n - 1)) * σ² + return σ² +end + +function sample_list_mean_var(::Type{Univariate}, sl::SampleList) + return sample_list_mean_cov(Univariate, sl, Val(true)) +end + +function sample_list_logmean(::Type{Univariate}, sl::SampleList) + n, samples, weights = get_data(sl) + logμ = sample_list_zero_element(sl) + @turbo for i in 1:n + logμ += weights[i] * log(samples[i]) + end + return logμ +end + +function sample_list_meanlogmean(::Type{Univariate}, sl::SampleList) + n, samples, weights = get_data(sl) + μlogμ = sample_list_zero_element(sl) + @turbo for i in 1:n + μlogμ += weights[i] * samples[i] * log(samples[i]) + end + return μlogμ +end + +function sample_list_mirroredlogmean(::Type{Univariate}, sl::SampleList) + n, samples, weights = get_data(sl) + @assert all(0 .<= samples .< 1) "mean of `mirrorlog` of variable does not apply to variables outside of the range [0, 1]" + mirμ = sample_list_zero_element(sl) + @turbo for i in 1:n + mirμ += weights[i] * log(1 - samples[i]) + end + return mirμ +end + +function sample_list_vague(::Type{Univariate}, nsamples::Int) + targetdist = Uniform(-100, 100) + preallocated = preallocate_samples(Float64, (), nsamples) + rand!(targetdist, preallocated) + return SampleList(Val(()), preallocated, fill(one(Float64) / nsamples, nsamples), 
nothing) +end + +## Multivariate + +function sample_list_mean!(μ, ::Type{Multivariate}, sl::SampleList) + n, samples, weights = get_data(sl) + k = length(μ) + @turbo for i in 1:n, j in 1:k + μ[j] += (weights[i] * samples[(i - 1) * k + j]) + end + return μ +end + +function sample_list_covm!(Σ, μ, ::Type{Multivariate}, sl::SampleList) + n, samples, weights = get_data(sl) + tmp = similar(μ) + k = length(tmp) + + @inbounds for i in 1:n + for j in 1:k + tmp[j] = samples[(i - 1) * k + j] - μ[j] + end + # Fast equivalent of Σ += w .* (tmp * tmp') + for h in 1:k, l in 1:k + Σ[(h - 1) * k + l] += weights[i] * tmp[h] * tmp[l] + end + end + s = n / (n - 1) + @turbo for i in 1:length(Σ) + Σ[i] *= s + end + return Σ +end + +function sample_list_mean_var(::Type{Multivariate}, sl::SampleList) + μ, Σ = sample_list_mean_cov(Multivariate, sl, Val(true)) + return μ, diag(Σ) +end + +function sample_list_logmean(::Type{Multivariate}, sl::SampleList) + n, samples, weights = get_data(sl) + logμ = sample_list_zero_element(sl) + k = length(logμ) + @turbo for i in 1:n, j in 1:k + logμ[j] += (weights[i] * log(samples[(i - 1) * k + j])) + end + return logμ +end + +function sample_list_meanlogmean(::Type{Multivariate}, sl::SampleList) + n, samples, weights = get_data(sl) + μlogμ = sample_list_zero_element(sl) + k = length(μlogμ) + @turbo for i in 1:n, j in 1:k + cs = samples[(i - 1) * k + j] + μlogμ[j] += (weights[i] * cs * log(cs)) + end + return μlogμ +end + +function sample_list_vague(::Type{Multivariate}, dims::Int, nsamples::Int) + targetdist = Uniform(-100, 100) + preallocated = preallocate_samples(Float64, (dims,), nsamples) + rand!(targetdist, preallocated) + return SampleList(Val((dims,)), preallocated, fill(one(Float64) / nsamples, nsamples), nothing) +end + +## Matrixvariate + +function sample_list_mean!(μ, ::Type{Matrixvariate}, sl::SampleList) + n, samples, weights = get_data(sl) + k = length(μ) + @turbo for i in 1:n, j in 1:k + μ[j] += (weights[i] * samples[(i - 1) * k + j]) + 
end + return μ +end + +function sample_list_covm!(Σ, μ, ::Type{Matrixvariate}, sl::SampleList) + n, samples, weights = get_data(sl) + k = length(μ) + rμ = reshape(μ, k) + tmp = similar(rμ) + @inbounds for i in 1:n + for j in 1:k + tmp[j] = samples[(i - 1) * k + j] - μ[j] + end + # Fast equivalent of Σ += w .* (tmp * tmp') + for h in 1:k, l in 1:k + Σ[(h - 1) * k + l] += weights[i] * tmp[h] * tmp[l] + end + end + s = n / (n - 1) + @turbo for i in 1:length(Σ) + Σ[i] *= s + end + return Σ +end + +function sample_list_mean_var(::Type{Matrixvariate}, sl::SampleList) + μ, Σ = sample_list_mean_cov(Matrixvariate, sl, Val(true)) + return μ, reshape(diag(Σ), size(μ)) +end + +function sample_list_logmean(::Type{Matrixvariate}, sl::SampleList) + n, samples, weights = get_data(sl) + logμ = sample_list_zero_element(sl) + k = length(logμ) + @turbo for i in 1:n, j in 1:k + logμ[j] += (weights[i] * log(samples[(i - 1) * k + j])) + end + return logμ +end + +function sample_list_meanlogmean(::Type{Matrixvariate}, sl::SampleList) + n, samples, weights = get_data(sl) + μlogμ = sample_list_zero_element(sl) + k = length(μlogμ) + @turbo for i in 1:n, j in 1:k + cs = samples[(i - 1) * k + j] + μlogμ[j] += (weights[i] * cs * log(cs)) + end + return μlogμ +end + +function sample_list_vague(::Type{Matrixvariate}, dims::Tuple{Int, Int}, nsamples::Int) + targetdist = Uniform(-100, 100) + preallocated = preallocate_samples(Float64, dims, nsamples) + rand!(targetdist, preallocated) + return SampleList(Val(dims), preallocated, fill(one(Float64) / nsamples, nsamples), nothing) +end + +## Array operations, broadcasting and mapping + +struct SamplesOnlyIterator{T, L} <: AbstractVector{T} + samplelist::L + + function SamplesOnlyIterator(samplelist::L) where {L <: SampleList} + return new{samples_type(eltype(samplelist)), L}(samplelist) + end +end + +samples_type(::Type{T}) where {L, R, T <: Tuple{L, R}} = L + +@inline Base.size(iter::SamplesOnlyIterator) = (length(iter.samplelist),) +@inline 
Base.getindex(iter::SamplesOnlyIterator, i::Int) = first(getindex(iter.samplelist, i)) + +@inline function Base.setindex!(iter::SamplesOnlyIterator, v, i::Int) + samples = get_linear_samples(iter.samplelist) + sample_len = prod(ndims(iter.samplelist)) + left = (i - 1) * sample_len + 1 + right = left + sample_len - 1 + copyto!(view(samples, left:right), v) + v +end + +get_linear_samples(iter::SamplesOnlyIterator) = get_linear_samples(iter.samplelist) + +## + +Base.iterate(sl::SampleList) = (sl[1], 2) +Base.iterate(sl::SampleList, state::Int) = state <= length(sl) ? (sl[state], state + 1) : nothing + +@inline Base.getindex(sl::SampleList, i::Int) = sample_list_get_index(variate_form(typeof(sl)), ndims(sl), sl, i) + +@inline function sample_list_get_index(::Type{Univariate}, ndims, sl, i) + return (get_linear_samples(sl)[i], get_linear_weights(sl)[i]) +end + +@inline function sample_list_get_index(::Type{Multivariate}, ndims, sl, i) + samples = get_linear_samples(sl) + left = (i - 1) * ndims + 1 + right = left + ndims - 1 + # ndims is compile-time here + return (SVector{ndims}(view(samples, left:right)), get_linear_weights(sl)[i]) +end + +@inline function sample_list_get_index(::Type{Matrixvariate}, ndims, sl, i) + p = prod(ndims) + samples = get_linear_samples(sl) + left = (i - 1) * p + 1 + right = left + p - 1 + # ndims are compile-time here + return (SMatrix{ndims[1], ndims[2]}(reshape(view(samples, left:right), ndims)), get_linear_weights(sl)[i]) +end + +## Transformation routines + +transform_samples(f::Function, sl::SampleList) = sample_list_transform_samples(variate_form(typeof(sl)), f, sl) + +@inline input_for_transform(::Type{Univariate}, samples, size, left, right) = samples[left] +@inline input_for_transform(::Type{Multivariate}, samples, size, left, right) = SVector{size}(view(samples, left:right)) +@inline input_for_transform(::Type{Matrixvariate}, samples, size, left, right) = SMatrix{size[1], size[2]}(reshape(view(samples, left:right), size)) + 
+function sample_list_transform_samples(::Type{U}, f::Function, sl::SampleList) where {U} + n, samples, weights = get_data(sl) + input_size = ndims(sl) + input_len = prod(input_size) + + # Here we simulate an original implementation of map function from Julia Base + # Trick here is to compute the first value so compiler may infer the actual output type + # Later on output_size and output_len are compile-time constants (given that `f` is type-stable) + first_item = f(input_for_transform(U, samples, input_size, 1, input_len)) + + # After computing first value compiler knows the output type and size of this type + output_size = size(first_item) + output_len = prod(output_size) + + preallocated = preallocate_samples(promote_type(eltype(first_item), eltype(samples)), output_size, n) + copyto!(view(preallocated, 1:output_len), first_item) + + # We then compute all values from 2 to n into a preallocated buffer + @views for i in 2:n + input_left = (i - 1) * input_len + 1 + input_right = input_left + input_len - 1 + output_left = (i - 1) * output_len + 1 + output_right = output_left + output_len - 1 + # We use static matrix size to ensure that we do not allocate extra memory on a heap + # Instead we try to do all computations on stack as much as possible + # If `f` function is "bad" and still allocates matrices this optimisation doesn't really work well + copyto!(preallocated[output_left:output_right], f(input_for_transform(U, samples, input_size, input_left, input_right))) + end + + # if `f` is type stable and returns StaticArray `output_size` is a compile-time constant + return SampleList(Val(output_size), preallocated, weights) +end + +function transform_weights!(f::Function, sl::SampleList) + n, _, weights = get_data(sl) + map!(f, weights, weights) + norm = sum(weights) + @turbo for i in 1:n + weights[i] /= norm + end + return sl +end diff --git a/src/prod.jl b/src/prod.jl index 5b21bbc..32fadfd 100644 --- a/src/prod.jl +++ b/src/prod.jl @@ -237,8 +237,6 @@ function 
BayesBase.logpdf(product::ProductOf, x) return logpdf(getleft(product), x) + logpdf(getright(product), x) end -BayesBase.variate_form(::P) where {P<:ProductOf} = variate_form(P) - function BayesBase.variate_form(::Type{ProductOf{L,R}}) where {L,R} return _check_product_variate_form(variate_form(L), variate_form(R)) end @@ -253,8 +251,6 @@ function _check_product_variate_form( ) end -BayesBase.value_support(::P) where {P<:ProductOf} = value_support(P) - function BayesBase.value_support(::Type{ProductOf{L,R}}) where {L,R} return _check_product_value_support(value_support(L), value_support(R)) end @@ -397,10 +393,7 @@ function BayesBase.samplefloattype(product::LinearizedProductOf) end BayesBase.variate_form(::Type{<:LinearizedProductOf{F}}) where {F} = variate_form(F) -BayesBase.variate_form(::LinearizedProductOf{F}) where {F} = variate_form(F) - BayesBase.value_support(::Type{<:LinearizedProductOf{F}}) where {F} = value_support(F) -BayesBase.value_support(::LinearizedProductOf{F}) where {F} = value_support(F) function Base.show(io::IO, product::LinearizedProductOf{F}) where {F} return print(io, "LinearizedProductOf(", F, ", length = ", product.length, ")") diff --git a/test/densities/pointmass_tests.jl b/test/densities/pointmass_tests.jl index a3b8a52..298a37e 100644 --- a/test/densities/pointmass_tests.jl +++ b/test/densities/pointmass_tests.jl @@ -6,7 +6,7 @@ scalar = rand(T) dist = PointMass(scalar) - @test variate_form(dist) === Univariate + @test variate_form(typeof(dist)) === Univariate @test_throws BoundsError dist[2] @test_throws BoundsError dist[2, 2] @@ -56,7 +56,7 @@ end vector = rand(T, N) dist = PointMass(vector) - @test variate_form(dist) === Multivariate + @test variate_form(typeof(dist)) === Multivariate @test dist[2] === vector[2] @test dist[3] === vector[3] @test_throws BoundsError dist[N + 1] @@ -114,7 +114,7 @@ end matrix = rand(T, N, N) dist = PointMass(matrix) - @test variate_form(dist) === Matrixvariate + @test variate_form(typeof(dist)) === 
Matrixvariate @test dist[2] === matrix[2] @test dist[3] === matrix[3] @test dist[3, 3] === matrix[3, 3] @@ -179,7 +179,7 @@ end matrix = convert(T, 5) * I dist = PointMass(matrix) - @test variate_form(dist) === Matrixvariate + @test variate_form(typeof(dist)) === Matrixvariate @test dist[2, 1] == zero(T) @test dist[3, 1] == zero(T) @test dist[3, 3] === matrix[3, 3] diff --git a/test/densities/samplelist_tests.jl b/test/densities/samplelist_tests.jl new file mode 100644 index 0000000..c0b4998 --- /dev/null +++ b/test/densities/samplelist_tests.jl @@ -0,0 +1,604 @@ +@testitem "SampleList: Internal functions" begin + import BayesBase: sample_list_zero_element + + @test sample_list_zero_element(SampleList([1.0, 1.0])) === 0.0 + @test sample_list_zero_element(SampleList([[1.0, 1.0], [1.0, 1.0]])) == [0.0, 0.0] + @test sample_list_zero_element(SampleList([[1.0 1.0; 1.0 1.0], [1.0 1.0; 1.0 1.0]])) == + [0.0 0.0; 0.0 0.0] + @test sample_list_zero_element( + SampleList([[1.0 1.0 1.0; 1.0 1.0 1.0], [1.0 1.0 1.0; 1.0 1.0 1.0]]) + ) == [0.0 0.0 0.0; 0.0 0.0 0.0] + @test sample_list_zero_element(SampleList([[1.0; 1.0], [1.0; 1.0]])) == [0.0; 0.0] +end + +@testitem "SampleList: Constructor" begin + using StableRNGs, StaticArrays + + import BayesBase: + deep_eltype, + get_samples, + get_weights, + sample_list_zero_element, + get_meta, + is_meta_present, + get_unnormalised_weights, + get_entropy, + get_logproposal, + get_logintegrand, + call_logproposal, + call_logintegrand, + transform_samples, + transform_weights!, + approximate_prod_with_sample_list + + rng = StableRNG(1234) + + for N in [5, 10, 100], type in [Float64, Float32, BigFloat] + scalar_samples = rand(rng, type, N) + scalar_samplelist = SampleList(scalar_samples) + + @test collect(get_samples(scalar_samplelist)) == first.(collect(scalar_samplelist)) + @test collect(get_weights(scalar_samplelist)) == fill(one(type) / N, N) + @test deep_eltype(scalar_samplelist) === type + @test eltype(scalar_samplelist) === 
Tuple{type,type} + @test eltype(get_weights(scalar_samplelist)) === type + @test variate_form(typeof(scalar_samplelist)) === Univariate + @test is_meta_present(scalar_samplelist) === false + + scalar_weights = rand(rng, type, N) + + @test_throws AssertionError SampleList(scalar_samples, scalar_weights) + scalar_weights ./= sum(scalar_weights) + scalar_samplelist = SampleList(scalar_samples, scalar_weights) + + @test collect(get_samples(scalar_samplelist)) == first.(collect(scalar_samplelist)) + @test collect(get_weights(scalar_samplelist)) == scalar_weights + @test deep_eltype(scalar_samplelist) === type + @test eltype(scalar_samplelist) === Tuple{type,type} + @test eltype(get_weights(scalar_samplelist)) === type + @test variate_form(typeof(scalar_samplelist)) === Univariate + @test is_meta_present(scalar_samplelist) === false + + vector_samples = [rand(rng, type, 2) for _ in 1:N] + vector_samplelist = SampleList(vector_samples) + + @test collect(get_samples(vector_samplelist)) == first.(collect(vector_samplelist)) + @test collect(get_weights(vector_samplelist)) == fill(one(type) / N, N) + @test deep_eltype(vector_samplelist) === type + @test eltype(vector_samplelist) === Tuple{SVector{2,type},type} + @test eltype(get_weights(vector_samplelist)) === type + @test variate_form(typeof(vector_samplelist)) === Multivariate + @test is_meta_present(vector_samplelist) === false + + vector_weights = rand(rng, type, N) + @test_throws AssertionError SampleList(vector_samples, vector_weights) + vector_weights ./= sum(vector_weights) + vector_samplelist = SampleList(vector_samples, vector_weights) + + @test collect(get_samples(vector_samplelist)) == first.(collect(vector_samplelist)) + @test collect(get_weights(vector_samplelist)) == vector_weights + @test deep_eltype(vector_samplelist) === type + @test eltype(vector_samplelist) === Tuple{SVector{2,type},type} + @test eltype(get_weights(vector_samplelist)) === type + @test variate_form(typeof(vector_samplelist)) === 
Multivariate + @test is_meta_present(vector_samplelist) === false + + matrix_samples = [rand(rng, type, 2, 2) for _ in 1:N] + matrix_samplelist = SampleList(matrix_samples) + + @test collect(get_samples(matrix_samplelist)) == first.(collect(matrix_samplelist)) + @test collect(get_weights(matrix_samplelist)) == fill(one(type) / N, N) + @test deep_eltype(matrix_samplelist) === type + @test eltype(matrix_samplelist) === Tuple{SMatrix{2,2,type,4},type} + @test eltype(get_weights(matrix_samplelist)) === type + @test variate_form(typeof(matrix_samplelist)) === Matrixvariate + @test is_meta_present(matrix_samplelist) === false + + matrix_weights = rand(rng, type, N) + @test_throws AssertionError SampleList(matrix_samples, matrix_weights) + matrix_weights ./= sum(matrix_weights) + matrix_samplelist = SampleList(matrix_samples, matrix_weights) + + @test collect(get_samples(matrix_samplelist)) == first.(collect(matrix_samplelist)) + @test collect(get_weights(matrix_samplelist)) == matrix_weights + @test deep_eltype(matrix_samplelist) === type + @test eltype(matrix_samplelist) === Tuple{SMatrix{2,2,type,4},type} + @test eltype(get_weights(matrix_samplelist)) === type + @test variate_form(typeof(matrix_samplelist)) === Matrixvariate + @test is_meta_present(matrix_samplelist) === false + end + + @test_throws AssertionError SampleList(rand(10), rand(5)) + @test_throws AssertionError SampleList(rand(5), rand(10)) + @test_throws AssertionError SampleList(rand(5), [-1 for _ in 1:5]) + + @test_throws AssertionError SampleList([rand(10) for _ in 1:10], rand(5)) + @test_throws AssertionError SampleList([rand(5) for _ in 1:5], rand(10)) + @test_throws AssertionError SampleList([rand(5) for _ in 1:5], [-1 for _ in 1:5]) + + @test_throws AssertionError SampleList([rand(10, 10) for _ in 1:10], rand(5)) + @test_throws AssertionError SampleList([rand(5, 5) for _ in 1:5], rand(10)) + @test_throws AssertionError SampleList([rand(5, 5) for _ in 1:5], [-1 for _ in 1:5]) +end + +@testitem 
"SampleList: Statistics" begin + using StableRNGs, StaticArrays, Distributions, LinearAlgebra + + import BayesBase: + deep_eltype, + get_samples, + get_weights, + sample_list_zero_element, + get_meta, + is_meta_present, + get_unnormalised_weights, + get_entropy, + get_logproposal, + get_logintegrand, + call_logproposal, + call_logintegrand, + transform_samples, + transform_weights!, + approximate_prod_with_sample_list + + rng = StableRNG(42) + + # All + for N in [5, 10, 100] + scalar_samples = rand(rng, N) + scalar_weights = rand(rng, N) + scalar_weights ./= sum(scalar_weights) + scalar_samplelist = SampleList(scalar_samples, scalar_weights) + arbitrary_f = (x) -> x .+ 1 + + # Checking i = 1:2 that cache is not corrupted + for i in 1:2 + @test mean(scalar_samplelist) ≈ sum(scalar_weights .* scalar_samples) + @test mean(log, scalar_samplelist) ≈ sum(scalar_weights .* log.(scalar_samples)) + @test mean(xtlog, scalar_samplelist) ≈ + sum(scalar_weights .* scalar_samples .* log.(scalar_samples)) + @test mean(mirrorlog, scalar_samplelist) ≈ + sum(scalar_weights .* log.(1.0 .- scalar_samples)) + @test mean(arbitrary_f, scalar_samplelist) ≈ + sum(scalar_weights .* arbitrary_f.(scalar_samples)) + end + + vector_samples = [rand(rng, 2) for _ in 1:N] + vector_weights = rand(rng, N) + vector_weights ./= sum(vector_weights) + vector_samplelist = SampleList(vector_samples, vector_weights) + + # Checking i = 1:2 that cache is not corrupted + for i in 1:2 + @test mean(vector_samplelist) ≈ sum(vector_weights .* vector_samples) + @test mean(log, vector_samplelist) ≈ + sum(vector_weights .* map(e -> log.(e), (vector_samples))) + @test mean(xtlog, vector_samplelist) ≈ + sum(vector_weights .* map(e -> e .* log.(e), (vector_samples))) + @test mean(arbitrary_f, vector_samplelist) ≈ + sum(vector_weights .* map(arbitrary_f, (vector_samples))) + end + + matrix_samples = [rand(rng, 2, 2) for _ in 1:N] + matrix_weights = rand(rng, N) + matrix_weights ./= sum(matrix_weights) + 
matrix_samplelist = SampleList(matrix_samples, matrix_weights) + + # Checking i = 1:2 that cache is not corrupted + for i in 1:2 + @test mean(matrix_samplelist) ≈ sum(matrix_weights .* matrix_samples) + @test mean(log, matrix_samplelist) ≈ + sum(matrix_weights .* map(e -> log.(e), matrix_samples)) + @test mean(xtlog, matrix_samplelist) ≈ + sum(matrix_weights .* map(e -> e .* log.(e), matrix_samples)) + @test mean(arbitrary_f, matrix_samplelist) ≈ + sum(matrix_weights .* map(arbitrary_f, matrix_samples)) + end + end + + uni_distribution = Gamma(rand(rng) + 1, rand(rng) + 2) + uni_samples = rand(rng, uni_distribution, 20_000) + uni_sample_list = SampleList(uni_samples) + + uni_distribution2 = Normal(rand(rng) + 1, rand(rng) + 2) + uni_samples2 = rand(rng, uni_distribution2, 20_000) + uni_sample_list2 = SampleList(uni_samples2) + + m = rand(rng, 3) + r = rand(rng, 3) + Σ = I + 2r * r' + mv_distribution = MvNormal(m, Σ) + mv_samples = [rand(rng, mv_distribution) for _ in 1:20_000] + mv_sample_list = SampleList(mv_samples) + + W1 = rand(rng, 3, 4) + r2 = rand(rng, 3) + W2 = I + 2r2 * r2' + r3 = rand(rng, 4) + W3 = I + 2r3 * r3' + mxv_distribution = MatrixNormal(W1, W2, W3) + mxv_samples = [rand(rng, mxv_distribution) for _ in 1:20_000] + mxv_sample_list = SampleList(mxv_samples) + + # Checking i = 1:2 that cache is not corrupted + for i in 1:2 + @test isapprox(mean(uni_sample_list), mean(uni_distribution), atol=0.1) + @test isapprox(mean(mv_sample_list), mean(mv_distribution), atol=0.1) + @test isapprox(mean(mxv_sample_list), mean(mxv_distribution), atol=1.0) + @test all( + isapprox.(mean_var(uni_sample_list2), mean_var(uni_distribution2), atol=0.1) + ) + @test isapprox(var(uni_sample_list), var(uni_distribution), atol=0.5) + @test isapprox(cov(uni_sample_list), var(uni_distribution), atol=0.5) + @test isapprox(var(mv_sample_list), var(mv_distribution), atol=0.1) + @test isapprox(cov(mv_sample_list), cov(mv_distribution), atol=0.1) + @test isapprox(var(mxv_sample_list), 
var(mxv_distribution), atol=0.1) + @test isapprox(cov(mxv_sample_list), cov(mxv_distribution), atol=1.0) + + @test isapprox(std(uni_sample_list), std(uni_distribution), atol=0.2) + @test isapprox(std(mv_sample_list), sqrt(cov(mv_distribution)), atol=0.2) + end + + mv_distribution = Dirichlet(rand(rng, 3)) + mv_samples = [rand(rng, mv_distribution) for _ in 1:20_000] + mv_sample_list = SampleList(mv_samples) + + r4 = rand(rng, 5) + W4 = I + 2r4 * r4' + + mxv_distribution = Wishart(5, W4) + mxv_samples = [rand(rng, mxv_distribution) for _ in 1:20_000] + mxv_sample_list = SampleList(mxv_samples) + + # Checking i = 1:2 that cache is not corrupted + for i in 1:2 + @test isapprox(mean(mxv_sample_list), mean(mxv_distribution), atol=1.0) + @test isapprox(var(mxv_sample_list), var(mxv_distribution), atol=2.5) + @test isapprox(cov(mxv_sample_list), cov(mxv_distribution), atol=5.0) + end +end + +@testitem "SampleList: vague" begin + @test variate_form(typeof(vague(SampleList))) === Univariate + @test variate_form(typeof(vague(SampleList, 2))) === Multivariate + @test variate_form(typeof(vague(SampleList, 2, 2))) === Matrixvariate + @test variate_form(typeof(vague(SampleList, (3, 4)))) === Matrixvariate + + @test ndims(vague(SampleList)) === 1 + @test ndims(vague(SampleList, 2)) === 2 + @test ndims(vague(SampleList, 2, 2)) === (2, 2) + @test ndims(vague(SampleList, (3, 4))) === (3, 4) + + for nsamples in [10, 100, 1000] + @test variate_form(typeof(vague(SampleList; nsamples=nsamples))) === Univariate + @test variate_form(typeof(vague(SampleList, 2; nsamples=nsamples))) === Multivariate + @test variate_form(typeof(vague(SampleList, 2, 2; nsamples=nsamples))) === + Matrixvariate + @test variate_form(typeof(vague(SampleList, (3, 4); nsamples=nsamples))) === + Matrixvariate + + @test length(vague(SampleList; nsamples=nsamples)) === nsamples + @test length(vague(SampleList, 2; nsamples=nsamples)) === nsamples + @test length(vague(SampleList, 2, 2; nsamples=nsamples)) === nsamples + 
@test length(vague(SampleList, (3, 4); nsamples=nsamples)) === nsamples + + @test ndims(vague(SampleList; nsamples=nsamples)) === 1 + @test ndims(vague(SampleList, 2; nsamples=nsamples)) === 2 + @test ndims(vague(SampleList, 2, 2; nsamples=nsamples)) === (2, 2) + @test ndims(vague(SampleList, (3, 4); nsamples=nsamples)) === (3, 4) + + @test size(vague(SampleList; nsamples=nsamples)) === (nsamples,) + @test size(vague(SampleList, 2; nsamples=nsamples)) === (nsamples,) + @test size(vague(SampleList, 2, 2; nsamples=nsamples)) === (nsamples,) + @test size(vague(SampleList, (3, 4); nsamples=nsamples)) === (nsamples,) + end +end + +@testitem "SampleListMeta" begin + using StableRNGs, StaticArrays, Distributions, LinearAlgebra + + import BayesBase: + deep_eltype, + get_samples, + get_weights, + sample_list_zero_element, + get_meta, + is_meta_present, + get_unnormalised_weights, + get_entropy, + get_logproposal, + get_logintegrand, + call_logproposal, + call_logintegrand, + transform_samples, + transform_weights!, + approximate_prod_with_sample_list + + @test_throws ErrorException get_meta(SampleList([0.1])) + + rng = StableRNG(1234) + + for uweights in [rand(rng, 2), rand(rng, 2)], + entropy in [-75.6, 234.2], + logproposal in [Normal(-3.0, 1.0), Gamma(2.0, 4.0), (x) -> x + 1.0], + logintegrand in [Normal(-1.0, 2.0), Beta(3.0, 4.0), (x) -> log(abs(x))] + + meta = SampleListMeta(uweights, entropy, logproposal, logintegrand) + sl = SampleList([0.1, 0.1], [0.5, 0.5], meta) + + @test get_meta(sl) === meta + @test get_unnormalised_weights(sl) == uweights + @test get_entropy(sl) == entropy + @test get_logproposal(sl) == logproposal + @test get_logintegrand(sl) == logintegrand + @test is_meta_present(sl) === true + + some_random_numbers = rand(rng, 100) + + __call(dist::Distribution, x) = logpdf(dist, x) + __call(dist::Function, x) = dist(x) + + @test map(e -> call_logproposal(sl, e), some_random_numbers) == + map(e -> __call(logproposal, e), some_random_numbers) + @test map(e 
-> call_logintegrand(sl, e), some_random_numbers) == + map(e -> __call(logintegrand, e), some_random_numbers) + end +end + +@testitem "Iteration utilities" begin + using StableRNGs, StaticArrays, Distributions, LinearAlgebra + + import BayesBase: + deep_eltype, + get_samples, + get_weights, + sample_list_zero_element, + get_meta, + is_meta_present, + get_unnormalised_weights, + get_entropy, + get_logproposal, + get_logintegrand, + call_logproposal, + call_logintegrand, + transform_samples, + transform_weights!, + approximate_prod_with_sample_list + + rng = StableRNG(42) + + uni_distribution = Uniform(-10rand(rng), 10rand(rng)) + + μ = rand(rng, 3) + L1 = rand(rng, 3, 3) + Σ = L1' * L1 + mv_distribution = MvNormal(μ, Σ) + + L2 = rand(rng, 3, 3) + W = L2' * L2 + mvx_distribution = Wishart(3, W) + + # Entity to entity + f1(e) = e .+ 1 + f2(e) = exp.(e) + + # Entity to Number + f3(e::Number) = e + 1 + f3(e::AbstractVector) = norm(e .+ 1) + f3(e::AbstractMatrix) = det(e .+ 1) + + # Entity to Vector + f4(e::Number) = @SVector [e, e] + f4(e::AbstractVector) = reverse(e) + f4(e::AbstractMatrix) = diag(e) + + # Entity to Matrix + f5(e::Number) = @SMatrix [e+1 e; e e+1] + f5(e::AbstractVector) = SMatrix{length(e),length(e)}(Diagonal(ones(length(e)))) + f5(e::AbstractMatrix) = [e[1, 1]+1 e[1, 2]; e[2, 1] e[2, 2]+1] + + for N in (500, 1000, 5_000) + for distribution in (uni_distribution, mv_distribution, mvx_distribution) + samples = [rand(rng, distribution) for _ in 1:N] + weights = ones(N) ./ N + samplelist = SampleList(samples, weights) + + @test collect(samplelist) == collect(zip(samples, weights)) + @test map(i -> samplelist[i], 1:N) == collect(zip(samples, weights)) + + for f in (f1, f2, f3, f4, f5) + @test all( + map( + e -> all(e[1] .≈ e[2]), + zip( + collect(transform_samples(f, samplelist)), + collect(zip(map(f, samples), weights)), + ), + ), + ) + @test all( + map( + e -> all(e[1] .≈ e[2]), + zip( + map(i -> (f(samplelist[i][1]), samplelist[i][2]), 1:N), + 
collect(zip(f.(samples), weights)), + ), + ), + ) + end + + iter = N:-1:1 + index = 0 + + old_weights = copy(weights) + + transform_weights!(w -> w * iter[index += 1], samplelist) + + newweights = map(prod, zip(old_weights, iter)) + newweights ./= sum(newweights) + + @test get_weights(samplelist) ≈ newweights + end + end +end + +@testitem "prod approximations" begin + using StableRNGs, StaticArrays, Distributions, LinearAlgebra + using SpecialFunctions: loggamma, digamma + + import BayesBase: + deep_eltype, + get_samples, + get_weights, + sample_list_zero_element, + get_meta, + is_meta_present, + get_unnormalised_weights, + get_entropy, + get_logproposal, + get_logintegrand, + call_logproposal, + call_logintegrand, + transform_samples, + transform_weights!, + approximate_prod_with_sample_list + + rng = StableRNG(1234) + + posdefm(rng, s) = begin + L = LowerTriangular(rand(rng, s, s)) + s * I + L' * L + end + + sizes = [10_000, 20_000, 30_000] + inputs = [ + ( + x=Normal(3.0, inv(7.0)), + y=Normal(-4.0, sqrt(6.0)), + mean_tol=[1e-2, 1e-2, 1e-2], + cov_tol=[1e-2, 1e-2, 1e-2], + entropy_tol=[1e-2, 1e-2, 1e-2], + ), + ( + x=Normal(3.0, sqrt(7.0)), + y=Normal(inv(6.0) * 4.0, sqrt(inv(6.0))), + mean_tol=[5e-2, 5e-2, 5e-2], + cov_tol=[1e-2, 1e-2, 1e-2], + entropy_tol=[5e-2, 5e-2, 5e-2], + ), + ( + x=Gamma(3.0, 1.0 / 7.0), + y=Gamma(4.0, 6.0), + mean_tol=[1e-2, 1e-2, 1e-2], + cov_tol=[1e-2, 1e-2, 1e-2], + entropy_tol=[3e-2, 3e-2, 3e-2], + ), + ( + x=MvNormal(10rand(rng, 4), posdefm(rng, 4)), + y=MvNormal(10rand(rng, 4), posdefm(rng, 4)), + mean_tol=[3e-1, 3e-1, 3e-1], + cov_tol=[6e-1, 6e-1, 6e-1], + entropy_tol=[4e-2, 4e-2, 4e-2], + ), + ( + x=InverseWishart(10.0, posdefm(rng, 2)), + y=InverseWishart(5.0, posdefm(rng, 2)), + mean_tol=[7e-2, 7e-2, 7e-2], + cov_tol=[5e-2, 5e-2, 5e-2], + entropy_tol=[2.0, 2.0, 2.0], # this tol is quite high + ), + ] + + struct AnalyticalProdForTesting end + + function Base.prod(::AnalyticalProdForTesting, left::Normal, right::Normal) + μ = 
(mean(left) * var(right) + mean(right) * var(left)) / (var(right) + var(left)) + v = (var(left) * var(right)) / (var(left) + var(right)) + return Normal(μ, sqrt(v)) + end + + function Base.prod(::AnalyticalProdForTesting, left::Gamma, right::Gamma) + return Gamma( + shape(left) + shape(right) - 1, + (scale(left) * scale(right)) / (scale(left) + scale(right)), + ) + end + + function Base.prod(::AnalyticalProdForTesting, left::MvNormal, right::MvNormal) + invcovleft = inv(cholesky(cov(left))) + invcovright = inv(cholesky(cov(right))) + Σ = Matrix(Hermitian(inv(cholesky(invcovleft + invcovright)))) + μ = Σ * (invcovleft * mean(left) + invcovright * mean(right)) + return MvNormal(μ, Σ) + end + + function Base.prod(::AnalyticalProdForTesting, left::InverseWishart, right::InverseWishart) + d = size(left, 1) + ldf, lS = params(left) + rdf, rS = params(right) + V = lS + rS + df = ldf + rdf + d + 1 + return InverseWishart(df, V) + end + + function Distributions.entropy(dist::InverseWishart) + d = size(dist, 1) + ν, S = params(dist) + d * (d - 1) / 4 * log(π) + mapreduce(i -> loggamma((ν + 1.0 - i) / 2), +, 1:d) + ν / 2 * d + (d + 1) / 2 * (logdet(S) - log(2)) - + (ν + d + 1) / 2 * mapreduce(i -> digamma((ν - d + i) / 2), +, 1:d) + end + + for (i, N) in enumerate(sizes) + for input in inputs + @testset let input = input + analytical = prod(AnalyticalProdForTesting(), input[:x], input[:y]) + approximation = approximate_prod_with_sample_list( + rng, input[:x], input[:y], N + ) + + @test is_meta_present(approximation) === true + @test length(approximation) === N + + μᵣ, Σᵣ = nothing, nothing + μₐ, Σₐ = nothing, nothing + + if (variate_form(typeof(input[:x])) === Univariate) + μᵣ, Σᵣ = mean_var(analytical) + μₐ, Σₐ = mean_var(approximation) + else + μᵣ, Σᵣ = mean_cov(analytical) + μₐ, Σₐ = mean_cov(approximation) + end + + @test norm(μᵣ .- μₐ) < input[:mean_tol][i] + @test norm(Σᵣ .- Σₐ) < input[:cov_tol][i] + @test abs(entropy(analytical) - entropy(approximation)) < + 
input[:entropy_tol][i] + + # Second order approximation here + if (variate_form(typeof(input[:x])) === Univariate) + analytical2 = prod(AnalyticalProdForTesting(), analytical, input[:x]) + approximation2 = approximate_prod_with_sample_list( + rng, input[:x], approximation, N + ) + + @test is_meta_present(approximation2) === true + @test length(approximation2) === N + + if (variate_form(typeof(input[:x])) === Univariate) + μᵣ, Σᵣ = mean_var(analytical2) + μₐ, Σₐ = mean_var(approximation2) + else + μᵣ, Σᵣ = mean_cov(analytical2) + μₐ, Σₐ = mean_cov(approximation2) + end + + @test norm(μᵣ .- μₐ) < input[:mean_tol][i] + @test norm(Σᵣ .- Σₐ) < input[:cov_tol][i] + @test abs(entropy(analytical2) - entropy(approximation2)) < + input[:entropy_tol][i] + end + end + end + end +end \ No newline at end of file