Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MvNormalMeanScalePrecision distribution #206

Merged
merged 33 commits into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from 31 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
97488a3
Add tests for MvNormalMeanScalePrecision
albertpod Aug 8, 2024
a590f7f
Add MvNormalMeanScalePrecision
albertpod Aug 8, 2024
d78da98
Fix distribution
albertpod Aug 8, 2024
45d0f59
Fix tests
albertpod Aug 9, 2024
9818bd3
Update structure and tests
albertpod Aug 9, 2024
22e611a
Add natural parameters related functions
albertpod Aug 12, 2024
3f467b8
Merge branch 'main' into dev_mvscalenormal
albertpod Aug 12, 2024
c9ad326
WIP: Parameters transforamtion
albertpod Aug 14, 2024
a0ca848
Add fisher information
albertpod Aug 15, 2024
8a37b2c
Add fisher tests
albertpod Aug 21, 2024
1260dd3
Add rand
albertpod Aug 21, 2024
d8b2370
Add MvNormalMeanScalePrecision to library.md
albertpod Aug 21, 2024
44e2ce6
test: add test exponentialfamily interface for MvNormalMeanScalePreci…
Nimrais Sep 20, 2024
49670f8
feat: add basic functions for MvNormalMeanScalePrecision
Nimrais Sep 20, 2024
1877a70
feat: draft MvNormalMeanScalePrecision
Nimrais Sep 20, 2024
118ccfd
fix: dimension match
Nimrais Sep 20, 2024
9d6159a
test: add check that samples are correct
Nimrais Sep 23, 2024
2fb5717
feat: implement getfisherinformation(::NaturalParametersSpace, ::Type…
Nimrais Sep 23, 2024
89a4932
feat: implement getfisherinformation(::NaturalParametersSpace, ::Type…
Nimrais Sep 23, 2024
e319963
fix: correct getfisherinformation(::MeanParametersSpace, ::Type{MvNor…
Nimrais Sep 24, 2024
77f4a0d
test: use test_exponentialfamily_interface and add MvNormalMeanScaleP…
Nimrais Sep 24, 2024
0b87569
Delete test/repopack-output.txt
Nimrais Sep 25, 2024
575c4af
Update test/distributions/normal_family/mv_normal_mean_scale_precisio…
Nimrais Sep 26, 2024
2178b10
test(fix): typo in @allocated cholinv(fi_small)
Nimrais Sep 26, 2024
2a5db13
fix: MvNormalMeanScalePrecision should be faster from 10 dimensions
Nimrais Oct 11, 2024
7a00b40
fix: bump BayesBase 1.4.0
Nimrais Oct 11, 2024
148d140
refactor: mean param fisher for MvNormalMeanScalePrecision
Nimrais Oct 11, 2024
9990280
fix: remove BlockArrays
Nimrais Oct 11, 2024
98f2343
fix: update BayesBase 1.5.0
Nimrais Oct 22, 2024
fc1103d
fix: use rand! in rand for MvGaussianMeanScalePrecision
Nimrais Oct 22, 2024
9bd15a9
fix: make C=0.7 for Fisher is faster test
Nimrais Oct 22, 2024
809dc67
test(fix): use benchmark
Nimrais Oct 22, 2024
24108bb
Change nr of samples belapsed and # dimensions
wouterwln Oct 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ TinyHugeNumbers = "783c9a47-75a3-44ac-a16b-f1ab7b3acf04"

[compat]
Aqua = "0.8.7"
BayesBase = "1.2"
BayesBase = "1.5.0"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there stuff in BayesBase 1.5.0 that you need specifically?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, the arrowheadmatrix and the all methods for it to make the CI pass

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Distributions = "0.25"
DomainSets = "0.5.2, 0.6, 0.7"
FastCholesky = "1.0"
Expand Down Expand Up @@ -57,8 +57,8 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CpuId = "adafc99b-e345-5852-983c-f28acb93d879"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

Expand Down
1 change: 1 addition & 0 deletions docs/src/library.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ ExponentialFamily.NormalWeightedMeanPrecision
ExponentialFamily.MvNormalMeanPrecision
ExponentialFamily.MvNormalMeanCovariance
ExponentialFamily.MvNormalWeightedMeanPrecision
ExponentialFamily.MvNormalMeanScalePrecision
ExponentialFamily.JointNormal
ExponentialFamily.JointGaussian
ExponentialFamily.WishartFast
Expand Down
1 change: 1 addition & 0 deletions src/ExponentialFamily.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ include("distributions/normal_family/mv_normal_mean_covariance.jl")
include("distributions/normal_family/mv_normal_mean_precision.jl")
include("distributions/normal_family/mv_normal_weighted_mean_precision.jl")
include("distributions/normal_family/normal_family.jl")
include("distributions/normal_family/mv_normal_mean_scale_precision.jl")
include("distributions/gamma_inverse.jl")
include("distributions/geometric.jl")
include("distributions/matrix_dirichlet.jl")
Expand Down
268 changes: 268 additions & 0 deletions src/distributions/normal_family/mv_normal_mean_scale_precision.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,268 @@
export MvNormalMeanScalePrecision, MvGaussianMeanScalePrecision

import Distributions: logdetcov, distrname, sqmahal, sqmahal!, AbstractMvNormal
import LinearAlgebra: diag, Diagonal, dot
import Base: ndims, precision, length, size, prod

"""
MvNormalMeanScalePrecision{T <: Real, M <: AbstractVector{T}} <: AbstractMvNormal

A multivariate normal distribution with mean `μ` and scale parameter `γ` that scales the identity precision matrix.

# Type Parameters
- `T`: The element type of the mean vector and scale parameter
- `M`: The type of the mean vector, which must be a subtype of `AbstractVector{T}`

# Fields
- `μ::M`: The mean vector of the multivariate normal distribution
- `γ::T`: The scale parameter that scales the identity precision matrix

# Notes
The precision matrix of this distribution is `γ * I`, where `I` is the identity matrix.
The covariance matrix is the inverse of the precision matrix, i.e., `(1/γ) * I`.
"""
struct MvNormalMeanScalePrecision{T <: Real, M <: AbstractVector{T}} <: AbstractMvNormal
μ::M
γ::T
end

const MvGaussianMeanScalePrecision = MvNormalMeanScalePrecision

function MvNormalMeanScalePrecision(μ::AbstractVector{<:Real}, γ::Real)
T = promote_type(eltype(μ), eltype(γ))
return MvNormalMeanScalePrecision(convert(AbstractArray{T}, μ), convert(T, γ))
end

function MvNormalMeanScalePrecision(μ::AbstractVector{<:Integer}, γ::Real)
return MvNormalMeanScalePrecision(float.(μ), float(γ))
end

function MvNormalMeanScalePrecision(μ::AbstractVector{T}) where {T}
return MvNormalMeanScalePrecision(μ, convert(T, 1))
end

function MvNormalMeanScalePrecision(μ::AbstractVector{T1}, γ::T2) where {T1, T2}
T = promote_type(T1, T2)
μ_new = convert(AbstractArray{T}, μ)
γ_new = convert(T, γ)(length(μ))
return MvNormalMeanScalePrecision(μ_new, γ_new)
end

function unpack_parameters(::Type{MvNormalMeanScalePrecision}, packed)
p₁ = view(packed, 1:length(packed)-1)
p₂ = packed[end]

return (p₁, p₂)
end

function isproper(::NaturalParametersSpace, ::Type{MvNormalMeanScalePrecision}, η, conditioner)
k = length(η) - 1
if length(η) < 2 || (length(η) !== k + 1)
return false
end
(η₁, η₂) = unpack_parameters(MvNormalMeanScalePrecision, η)
return isnothing(conditioner) && isone(size(η₂, 1)) && isposdef(-η₂)
end

function (::MeanToNatural{MvNormalMeanScalePrecision})(tuple_of_θ::Tuple{Any, Any})
(μ, γ) = tuple_of_θ
return (γ * μ, - γ / 2)
end

function (::NaturalToMean{MvNormalMeanScalePrecision})(tuple_of_η::Tuple{Any, Any})
(η₁, η₂) = tuple_of_η
γ = -2 * η₂
return (η₁ / γ, γ)
end

function nabs2(x)
return sum(map(abs2, x))
end

getsufficientstatistics(::Type{MvNormalMeanScalePrecision}) = (identity, nabs2)

# Conversions
function Base.convert(
::Type{MvNormal{T, C, M}},
dist::MvNormalMeanScalePrecision
) where {T <: Real, C <: Distributions.PDMats.PDMat{T, Matrix{T}}, M <: AbstractVector{T}}
m, σ = mean(dist), std(dist)
return MvNormal(convert(M, m), convert(T, σ))
end

function Base.convert(
::Type{MvNormalMeanScalePrecision{T, M}},
dist::MvNormalMeanScalePrecision
) where {T <: Real, M <: AbstractArray{T}}
m, γ = mean(dist), dist.γ
return MvNormalMeanScalePrecision{T, M}(convert(M, m), convert(T, γ))
end

function Base.convert(
::Type{MvNormalMeanScalePrecision{T}},
dist::MvNormalMeanScalePrecision
) where {T <: Real}
return convert(MvNormalMeanScalePrecision{T, AbstractArray{T, 1}}, dist)
end

function Base.convert(::Type{MvNormalMeanCovariance}, dist::MvNormalMeanScalePrecision)
m, σ = mean(dist), cov(dist)
return MvNormalMeanCovariance(m, σ * diagm(ones(length(m))))
end

function Base.convert(::Type{MvNormalMeanPrecision}, dist::MvNormalMeanScalePrecision)
m, γ = mean(dist), precision(dist)
return MvNormalMeanPrecision(m, γ * diagm(ones(length(m))))
end

function Base.convert(::Type{MvNormalWeightedMeanPrecision}, dist::MvNormalMeanScalePrecision)
m, γ = mean(dist), precision(dist)
return MvNormalWeightedMeanPrecision(γ * m, γ * diagm(ones(length(m))))
end

Distributions.distrname(::MvNormalMeanScalePrecision) = "MvNormalMeanScalePrecision"

BayesBase.weightedmean(dist::MvNormalMeanScalePrecision) = precision(dist) * mean(dist)

BayesBase.mean(dist::MvNormalMeanScalePrecision) = dist.μ
BayesBase.mode(dist::MvNormalMeanScalePrecision) = mean(dist)
BayesBase.var(dist::MvNormalMeanScalePrecision) = diag(cov(dist))
BayesBase.cov(dist::MvNormalMeanScalePrecision) = cholinv(invcov(dist))
BayesBase.invcov(dist::MvNormalMeanScalePrecision) = scale(dist) * I(length(mean(dist)))
BayesBase.std(dist::MvNormalMeanScalePrecision) = cholsqrt(cov(dist))
BayesBase.logdetcov(dist::MvNormalMeanScalePrecision) = -chollogdet(invcov(dist))
BayesBase.scale(dist::MvNormalMeanScalePrecision) = dist.γ
BayesBase.params(dist::MvNormalMeanScalePrecision) = (mean(dist), scale(dist))

function Distributions.sqmahal(dist::MvNormalMeanScalePrecision, x::AbstractVector)
T = promote_type(eltype(x), paramfloattype(dist))
return sqmahal!(similar(x, T), dist, x)
end

function Distributions.sqmahal!(r, dist::MvNormalMeanScalePrecision, x::AbstractVector)
Nimrais marked this conversation as resolved.
Show resolved Hide resolved
μ, γ = params(dist)
@inbounds @simd for i in 1:length(r)
r[i] = μ[i] - x[i]
end
return dot3arg(r, γ, r) # x' * A * x
end

Base.eltype(::MvNormalMeanScalePrecision{T}) where {T} = T
Base.precision(dist::MvNormalMeanScalePrecision) = invcov(dist)
Base.length(dist::MvNormalMeanScalePrecision) = length(mean(dist))
Base.ndims(dist::MvNormalMeanScalePrecision) = length(dist)
Base.size(dist::MvNormalMeanScalePrecision) = (length(dist),)

Base.convert(::Type{<:MvNormalMeanScalePrecision}, μ::AbstractVector, γ::Real) = MvNormalMeanScalePrecision(μ, γ)

function Base.convert(::Type{<:MvNormalMeanScalePrecision{T}}, μ::AbstractVector, γ::Real) where {T <: Real}
MvNormalMeanScalePrecision(convert(AbstractArray{T}, μ), convert(T, γ))
end

BayesBase.vague(::Type{<:MvNormalMeanScalePrecision}, dims::Int) =
MvNormalMeanScalePrecision(zeros(Float64, dims), convert(Float64, tiny))

BayesBase.default_prod_rule(::Type{<:MvNormalMeanScalePrecision}, ::Type{<:MvNormalMeanScalePrecision}) = PreserveTypeProd(Distribution)

function BayesBase.prod(::PreserveTypeProd{Distribution}, left::MvNormalMeanScalePrecision, right::MvNormalMeanScalePrecision)
w = scale(left) + scale(right)
m = (scale(left) * mean(left) + scale(right) * mean(right)) / w
return MvNormalMeanScalePrecision(m, w)
end

BayesBase.default_prod_rule(::Type{<:MultivariateNormalDistributionsFamily}, ::Type{<:MvNormalMeanScalePrecision}) = PreserveTypeProd(Distribution)

function BayesBase.prod(
::PreserveTypeProd{Distribution},
left::L,
right::R
) where {L <: MultivariateNormalDistributionsFamily, R <: MvNormalMeanScalePrecision}
wleft = convert(MvNormalWeightedMeanPrecision, left)
wright = convert(MvNormalWeightedMeanPrecision, right)
return prod(BayesBase.default_prod_rule(wleft, wright), wleft, wright)
end

function BayesBase.rand(rng::AbstractRNG, dist::MvGaussianMeanScalePrecision{T}) where {T}
Nimrais marked this conversation as resolved.
Show resolved Hide resolved
μ, γ = params(dist)
d = length(μ)
return rand!(rng, dist, Vector{T}(undef, d))
end

function BayesBase.rand(rng::AbstractRNG, dist::MvGaussianMeanScalePrecision{T}, size::Int64) where {T}
container = Matrix{T}(undef, length(dist), size)
return rand!(rng, dist, container)
end

# FIXME: This is not the most efficient way to generate random samples within container
# it needs to work with scale method, not with std
function BayesBase.rand!(
Nimrais marked this conversation as resolved.
Show resolved Hide resolved
rng::AbstractRNG,
dist::MvGaussianMeanScalePrecision,
container::AbstractArray{T}
) where {T <: Real}
preallocated = similar(container)
randn!(rng, reshape(preallocated, length(preallocated)))
μ, L = mean_std(dist)
@views for i in axes(preallocated, 2)
copyto!(container[:, i], μ)
mul!(container[:, i], L, preallocated[:, i], 1, 1)
end
container
end

function getsupport(ef::ExponentialFamilyDistribution{MvNormalMeanScalePrecision})
dim = length(getnaturalparameters(ef)) - 1
return Domain(IndicatorFunction{AbstractVector}(MvNormalDomainIndicator(dim)))
end

isbasemeasureconstant(::Type{MvNormalMeanScalePrecision}) = ConstantBaseMeasure()

getbasemeasure(::Type{MvNormalMeanScalePrecision}) = (x) -> (2π)^(-length(x) / 2)

getlogbasemeasure(::Type{MvNormalMeanScalePrecision}) = (x) -> -length(x) / 2 * log2π

getlogpartition(::NaturalParametersSpace, ::Type{MvNormalMeanScalePrecision}) =
(η) -> begin
η1 = @view η[1:end-1]
η2 = η[end]
k = length(η1)
Cinv = inv(η2)
return -dot(η1, 1/4*Cinv, η1) - (k / 2)*log(-2*η2)
end

getgradlogpartition(::NaturalParametersSpace, ::Type{MvNormalMeanScalePrecision}) =
(η) -> begin
η1 = @view η[1:end-1]
η2 = η[end]
inv2 = inv(η2)
k = length(η1)
return pack_parameters(MvNormalMeanCovariance, (-1/(2*η2) * η1, dot(η1,η1) / 4*inv2^2 - k/2 * inv2))
end

getfisherinformation(::NaturalParametersSpace, ::Type{MvNormalMeanScalePrecision}) =
(η) -> begin
η1 = @view η[1:end-1]
η2 = η[end]
k = length(η1)

η1_part = -inv(2*η2)* I(length(η1))
η1η2 = zeros(k, 1)
η1η2 .= η1*inv(2*η2^2)

η2_part = k*inv(2abs2(η2)) - dot(η1,η1) / (2*η2^3)

return ArrowheadMatrix(η2_part, η1η2, diag(η1_part))
end


getfisherinformation(::MeanParametersSpace, ::Type{MvNormalMeanScalePrecision}) =
(θ) -> begin
μ = @view θ[1:end-1]
γ = θ[end]
k = length(μ)

matrix = zeros(eltype(μ), (k+1))
matrix[1:k] .= γ
matrix[k+1] = k*inv(2abs2(γ))
return Diagonal(matrix)
end
2 changes: 1 addition & 1 deletion test/distributions/distributions_setuptests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -557,4 +557,4 @@ function test_generic_simple_exponentialfamily_product(
end

return true
end
end
Loading
Loading