Skip to content

Commit

Permalink
test: use test_exponentialfamily_interface and add MvNormalMeanScaleP…
Browse files Browse the repository at this point in the history
…recision efficency test

fix: remove unneeded code

fix: remove not needed stuff

fix: remove unused code

test: add efficency test

fix: return distributions_setuptests to HEAD

test(fix): typo

test(fix): remove unneeded testset

test(fix): update efficency test
  • Loading branch information
Nimrais committed Sep 25, 2024
1 parent e319963 commit 77f4a0d
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 9 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ TinyHugeNumbers = "783c9a47-75a3-44ac-a16b-f1ab7b3acf04"
[compat]
Aqua = "0.8.7"
BayesBase = "1.2"
BlockArrays = "1.1.1"
Distributions = "0.25"
DomainSets = "0.5.2, 0.6, 0.7"
FastCholesky = "1.0"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -245,13 +245,14 @@ getfisherinformation(::NaturalParametersSpace, ::Type{MvNormalMeanScalePrecision
η2 = η[end]
k = length(η1)

inv_η2 = inv(η2)
η1_part = -1/(2*inv_η2)* I(length(η1))
η1_part = -inv(2*η2)* I(length(η1))
η1η2 = zeros(k, 1)
η1η2 .= 2*η1/inv_η2^2
η1η2 .= η1*inv(2*η2^2)
#η₁/(2abs2(η₂))

η2_part = zeros(1, 1)
η2_part .= -dot(η1,η1) / 2*inv_η2^3 + k/(2inv_η2)
η2_part .= k*inv(2abs2(η2)) - dot(η1,η1) / (2*η2^3)
# inv(2abs2(η₂))-abs2(η₁)/(2(η₂^3))

fisher = BlockArray{eltype(η)}(undef_blocks, [k, 1], [k, 1])

Expand All @@ -272,10 +273,10 @@ getfisherinformation(::MeanParametersSpace, ::Type{MvNormalMeanScalePrecision})
μ_part = γ * I(k)

μγ_part = zeros(k, 1)
μγ_part .= μ
μγ_part .= 0

γ_part = zeros(1, 1)
γ_part .= k/(2*γ^2)
γ_part .= k*inv(2abs2(γ))

fisher = BlockArray{eltype(θ)}(undef_blocks, [k, 1], [k, 1])

Expand All @@ -285,4 +286,4 @@ getfisherinformation(::MeanParametersSpace, ::Type{MvNormalMeanScalePrecision})
fisher[Block(2), Block(2)] = γ_part

return fisher
end
end
2 changes: 1 addition & 1 deletion test/distributions/distributions_setuptests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -557,4 +557,4 @@ function test_generic_simple_exponentialfamily_product(
end

return true
end
end
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,27 @@ end

rng = StableRNG(42)

for s in 2:5
for s in 1:6
μ = randn(rng, s)
γ = rand(rng)

@testset let d = MvNormalMeanScalePrecision(μ, γ)
ef = test_exponentialfamily_interface(d;)
end
end

μ = randn(rng, 1)
γ = rand(rng)

d = MvNormalMeanScalePrecision(μ, γ)
ef = convert(ExponentialFamilyDistribution, d)

d1d = NormalMeanPrecision(μ[1], γ)
ef1d = convert(ExponentialFamilyDistribution, d1d)

@test logpartition(ef) logpartition(ef1d)
@test gradlogpartition(ef) gradlogpartition(ef1d)
@test fisherinformation(ef) fisherinformation(ef1d)
end

@testitem "MvNormalMeanScalePrecision: Stats methods" begin
Expand Down Expand Up @@ -164,3 +177,52 @@ end
end
end
end

@testitem "MvNormalMeanScalePrecision: Fisher is faster then for full parametrization" begin
include("./normal_family_setuptests.jl")
using BenchmarkTools
using LinearAlgebra
using JET

rng = StableRNG(42)
for k in 20:40
μ = randn(rng, k)
γ = rand(rng)
cov = γ * I(k)

ef_small = convert(ExponentialFamilyDistribution, MvNormalMeanScalePrecision(μ, γ))
ef_full = convert(ExponentialFamilyDistribution, MvNormalMeanCovariance(μ, cov))

fi_small = fisherinformation(ef_small)
fi_full = fisherinformation(ef_full)

@test_opt fisherinformation(ef_small)
@test_opt fisherinformation(ef_full)

fi_mvsp_time = @elapsed fisherinformation(ef_small)
fi_mvsp_alloc = @allocated fisherinformation(ef_small)

fi_full_time = @elapsed fisherinformation(ef_full)
fi_full_alloc = @allocated fisherinformation(ef_full)

@test_opt cholinv(fi_small)
@test_opt cholinv(fi_full)

cholinv_time_small = @elapsed cholinv(fi_small)
cholinv_alloc_small = @allocated fisherinformation(ef_small)

cholinv_time_full = @elapsed cholinv(fi_full)
cholinv_alloc_full = @allocated cholinv(fi_full)

fi_small = fisherinformation(ef_small)
fi_full = fisherinformation(ef_full)

# small time is supposed to be O(k) and full time is supposed to O(k^2)
# the constant C is selected to account to fluctuations in test runs
C = 0.9
@test fi_mvsp_time < fi_full_time/(C*k)
@test fi_mvsp_alloc < fi_full_alloc/(C*k)
@test cholinv_time_small < cholinv_time_full/(C*k)
@test cholinv_alloc_small < cholinv_alloc_full/(C*k)
end
end

0 comments on commit 77f4a0d

Please sign in to comment.