Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement gradlogpartition for Exponential Family Distributions #149

Merged
merged 79 commits into from
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from 57 commits
Commits
Show all changes
79 commits
Select commit Hold shift + click to select a range
b1a130c
feat: implement MvNormalMeanCovariance gradlogpartition
Nimrais Dec 14, 2023
f831e49
test(fix): make nsample to compute sufficient_statistics a keyword ar…
Nimrais Dec 14, 2023
9921d34
test(fix): make gradlogpartition test a bit more obvious
Nimrais Dec 14, 2023
65f5e6b
feat: add getgradlogpartition for VonMises
Nimrais Dec 14, 2023
ed317d2
add beta grad log partition
Dec 15, 2023
7086c08
add grad binomial
Dec 15, 2023
3389b2a
add gradlogpartition function
HoangMHNguyen Dec 15, 2023
f274500
add getgradlogpartition poisson
HoangMHNguyen Dec 15, 2023
0306250
add getgradlogparition exponential
HoangMHNguyen Dec 15, 2023
f05cfd3
Merge branch 'implement-grad-logpartition' into gradlogpartition_expo…
Nimrais Dec 15, 2023
9ad7b0a
fix: grad is a vector, not a scalar
Nimrais Dec 15, 2023
decb514
Merge pull request #152 from biaslab/gradlogpartition_exponential
Nimrais Dec 15, 2023
ae07421
fix: return gradlog as vector
HoangMHNguyen Dec 15, 2023
5f849dc
fix: return gradlog as vector
HoangMHNguyen Dec 15, 2023
4c591f9
add getgradlogparition exponential
HoangMHNguyen Dec 15, 2023
f6a760b
feat: implement MvNormalMeanCovariance gradlogpartition
Nimrais Dec 14, 2023
076b656
test(fix): make nsample to compute sufficient_statistics a keyword ar…
Nimrais Dec 14, 2023
0fc2d84
test(fix): make gradlogpartition test a bit more obvious
Nimrais Dec 14, 2023
9a585bd
feat: add getgradlogpartition for VonMises
Nimrais Dec 14, 2023
223719b
fix: grad is a vector, not a scalar
Nimrais Dec 15, 2023
e25b3f7
Merge pull request #151 from biaslab/gradlogpartition_poisson
Nimrais Dec 15, 2023
bd4856a
Merge pull request #150 from biaslab/gradlogpartition_bernoulli
Nimrais Dec 15, 2023
8138d8f
Dirichlet, Gamma and Geometric distributions gradlogpartition impleme…
wouterwln Dec 15, 2023
85a651e
Add gradient calculation for log partition
bartvanerp Dec 18, 2023
094558d
Merge pull request #153 from biaslab/grad_binomial
Nimrais Dec 18, 2023
3954fa3
Merge pull request #154 from biaslab/grad_beta
Nimrais Dec 18, 2023
4991bb6
fix wishart gradient
Dec 18, 2023
ba3d4fb
fix(normal): typo in getgradlogpartition(::NaturalParametersSpace, ::…
Nimrais Dec 18, 2023
462b272
Merge pull request #157 from biaslab/grad-uvnormal
Nimrais Dec 18, 2023
993105f
finalize wishart
Dec 18, 2023
f3668a1
test(fix): convergence for wishart could be a bit slower on arm64
Nimrais Dec 18, 2023
83d8729
Merge pull request #158 from biaslab/wishart_grad
Nimrais Dec 18, 2023
f2d469f
add gradient
Dec 18, 2023
3b99a7b
add gradient of weibull
Dec 18, 2023
af1c6fe
Merge pull request #159 from biaslab/grad_wishart_inverse
Nimrais Dec 18, 2023
7c315c6
Merge pull request #160 from biaslab/grad_weibull
Nimrais Dec 18, 2023
b66905c
add gradient of vmf
Dec 18, 2023
f7cbf44
add Rayleigh
Dec 18, 2023
396fa48
add gradient pareto
Dec 18, 2023
1fdaf97
add normal gamma gradient and move mvdigamma function to common
Dec 18, 2023
401094a
add negative binomial gradient
Dec 18, 2023
5fab5c8
add gradient of matrix dirichlet
Dec 18, 2023
0a65861
Merge pull request #164 from biaslab/grad_normal_gamma
Nimrais Dec 18, 2023
df7dd52
Merge pull request #165 from biaslab/grad_negative_binomial
Nimrais Dec 18, 2023
0199519
Merge pull request #163 from biaslab/grad_pareto
Nimrais Dec 18, 2023
6fb6491
Merge pull request #162 from biaslab/grad_rayleigh
Nimrais Dec 18, 2023
ad518d8
feat: add chisq grad logpartition
Nimrais Dec 18, 2023
83c0d8e
add gradient of log normal
Dec 18, 2023
044dc86
Merge pull request #167 from biaslab/grad_log_normal
Nimrais Dec 18, 2023
5fb7daa
Merge pull request #161 from biaslab/grad_vmf
Nimrais Dec 18, 2023
5fdb47a
add gradient of erlang
Dec 18, 2023
6c369dd
add gradient for categorical
Dec 18, 2023
d50e6bd
Merge pull request #170 from biaslab/grad_categorical
Nimrais Dec 18, 2023
e5a849b
add gradient inverse gamma
Dec 18, 2023
5d6fa7c
Merge pull request #171 from biaslab/grad_inverse_gamma
Nimrais Dec 18, 2023
4af2509
Merge pull request #166 from biaslab/grad_matrix_dirichlet
Nimrais Dec 18, 2023
6bf3a6b
add gradient for Laplace
Dec 18, 2023
258198b
Merge pull request #172 from biaslab/grad_laplace
Nimrais Dec 18, 2023
d1cf1d6
Update README.md
bartvanerp Jan 2, 2024
31d5e51
add: kurtosis and skewness, fix: piracy to piracies
Jan 7, 2024
440e43d
add skewness for univariate normals and cast cov of a univariate to var
Jan 7, 2024
575a16a
remove cov=var statement
Jan 9, 2024
d1301e2
add tests
bvdmitri Jan 9, 2024
05d7930
make format
bvdmitri Jan 9, 2024
5093e67
Revert `piracy = false` change
bvdmitri Jan 9, 2024
4c8fa10
Use Julia 1.10 for tests
bvdmitri Jan 9, 2024
a7df4ef
2prev
bvdmitri Jan 9, 2024
2d87607
Fix isapprox for Normal family of distributions
bvdmitri Jan 9, 2024
f563f7f
Update README.md
albertpod Jan 10, 2024
8413bf4
Update examples.md
albertpod Jan 10, 2024
ca8e43b
Update make.jl
albertpod Jan 10, 2024
8b9df24
Update gamma_shape_rate_tests.jl
albertpod Jan 10, 2024
9edc9ea
adjust docs deployment settings
bvdmitri Jan 12, 2024
46d662b
fix: use StableRNG
Nimrais Jan 22, 2024
8e5eeff
test: do not test gradient for MvNormalWishart
Nimrais Jan 22, 2024
6be1d71
add gradient of erlang
Dec 18, 2023
ee9c9fa
Merge pull request #168 from ReactiveBayes/grad_erlang
Nimrais Jan 22, 2024
64c3330
docs: remove TODO from error messages
Nimrais Jan 22, 2024
6f31e89
Add more tests
bvdmitri Jan 31, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/src/interface.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,13 @@ isproper
getbasemeasure
getsufficientstatistics
getlogpartition
getgradlogpartition
getfisherinformation
getsupport
basemeasure
sufficientstatistics
logpartition
gradlogpartition
fisherinformation
isbasemeasureconstant
ConstantBaseMeasure
Expand Down
2 changes: 2 additions & 0 deletions src/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,5 @@ function binomial_prod(n, p, x)
end
end
end

mvdigamma(η,p) = sum( digamma(η + (one(d) - d)/2) for d=1:p)
5 changes: 5 additions & 0 deletions src/distributions/bernoulli.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,11 @@ getlogpartition(::NaturalParametersSpace, ::Type{Bernoulli}) = (η) -> begin
return -log(logistic(-η₁))
end

getgradlogpartition(::NaturalParametersSpace, ::Type{Bernoulli}) = (η) -> begin
(η₁,) = unpack_parameters(Bernoulli, η)
return SA[logistic(η₁)]
end

getfisherinformation(::NaturalParametersSpace, ::Type{Bernoulli}) = (η) -> begin
(η₁,) = unpack_parameters(Bernoulli, η)
f = logistic(-η₁)
Expand Down
10 changes: 10 additions & 0 deletions src/distributions/beta.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,16 @@ getlogpartition(::NaturalParametersSpace, ::Type{Beta}) = (η) -> begin
return logbeta(η₁ + one(η₁), η₂ + one(η₂))
end

getgradlogpartition(::NaturalParametersSpace, ::Type{Beta}) = (η) -> begin
(η₁, η₂) = unpack_parameters(Beta, η)
η₁p = η₁ + one(η₁)
η₂p = η₂ + one(η₂)
ηsum = η₁p + η₂p
dig = digamma(ηsum)

return SA[digamma(η₁p) - dig, digamma(η₂p) - dig]
end

getfisherinformation(::NaturalParametersSpace, ::Type{Beta}) = (η) -> begin
(η₁, η₂) = unpack_parameters(Beta, η)
psia = trigamma(η₁ + one(η₁))
Expand Down
5 changes: 5 additions & 0 deletions src/distributions/binomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ getlogpartition(::NaturalParametersSpace, ::Type{Binomial}, ntrials) = (η) -> b
return ntrials * log1pexp(η₁)
end

getgradlogpartition(::NaturalParametersSpace, ::Type{Binomial}, ntrials) = (η) -> begin
(η₁,) = unpack_parameters(Binomial, η)
return SA[ntrials*exp(η₁) / (one(η₁) + exp(η₁))]
end

getfisherinformation(::NaturalParametersSpace, ::Type{Binomial}, ntrials) = (η) -> begin
(η₁,) = unpack_parameters(Binomial, η)
aux = logistic(η₁)
Expand Down
13 changes: 13 additions & 0 deletions src/distributions/categorical.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,19 @@ getlogpartition(::NaturalParametersSpace, ::Type{Categorical}, conditioner) =
return logsumexp(η)
end

getgradlogpartition(::NaturalParametersSpace, ::Type{Categorical}, conditioner) =
(η) -> begin
if (conditioner !== length(η))
throw(
DimensionMismatch(
"Cannot evaluate the logparition of the `Categorical` with `conditioner = $(conditioner)` and vector of natural parameters `η = $(η)`"
)
)
end
sumη = vmapreduce(exp, +, η)
return vmap(d->exp(d)/sumη ,η)
end

getfisherinformation(::NaturalParametersSpace, ::Type{Categorical}, conditioner) =
(η) -> begin
if (conditioner !== length(η))
Expand Down
5 changes: 5 additions & 0 deletions src/distributions/chi_squared.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ getlogpartition(::NaturalParametersSpace, ::Type{Chisq}) = (η) -> begin
return loggamma(η1 + o) + (η1 + o) * logtwo
end

getgradlogpartition(::NaturalParametersSpace, ::Type{Chisq}) = (η) -> begin
(η1,) = unpack_parameters(Chisq, η)
return SA[digamma(η1 + one(η1)) + logtwo]
end

getfisherinformation(::NaturalParametersSpace, ::Type{Chisq}) = (η) -> begin
(η1,) = unpack_parameters(Chisq, η)
SA[trigamma(η1 + one(η1));;]
Expand Down
10 changes: 10 additions & 0 deletions src/distributions/dirichlet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ getfisherinformation(::NaturalParametersSpace, ::Type{Dirichlet}) =
return Diagonal(map(d -> trigamma(d + 1), η1)) - Ones{Float64}(n, n) * trigamma(sum(η1) + n)
end

getgradlogpartition(::NaturalParametersSpace, ::Type{Dirichlet}) = (η) -> begin
(η1,) = unpack_parameters(Dirichlet, η)
return digamma.(η1 .+ 1) .- digamma(sum(η1 .+ 1))
end

# Mean parametrization

getlogpartition(::MeanParametersSpace, ::Type{Dirichlet}) = (θ) -> begin
Expand All @@ -83,3 +88,8 @@ getfisherinformation(::MeanParametersSpace, ::Type{Dirichlet}) = (θ) -> begin
n = length(α)
return Diagonal(map(d -> trigamma(d), α)) - Ones{Float64}(n, n) * trigamma(sum(α))
end

getgradlogpartition(::MeanParametersSpace, ::Type{Dirichlet}) = (θ) -> begin
(α,) = unpack_parameters(Dirichlet, θ)
return digamma.(α) .- digamma(sum(α))
end
5 changes: 5 additions & 0 deletions src/distributions/exponential.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ getlogpartition(::NaturalParametersSpace, ::Type{Exponential}) = (η) -> begin
return -log(-η₁)
end

getgradlogpartition(::NaturalParametersSpace, ::Type{Exponential}) = (η) -> begin
(η₁,) = unpack_parameters(Exponential, η)
return SA[-1/η₁]
end

getfisherinformation(::NaturalParametersSpace, ::Type{Exponential}) = (η) -> begin
(η₁,) = unpack_parameters(Exponential, η)
SA[inv(η₁^2);;]
Expand Down
10 changes: 10 additions & 0 deletions src/distributions/gamma_family/gamma_family.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,11 @@ getfisherinformation(::NaturalParametersSpace, ::Type{Gamma}) = (η) -> begin
SA[trigamma(η₁ + one(η₁)) -inv(η₂); -inv(η₂) (η₁+one(η₁))/(η₂^2)]
end

getgradlogpartition(::NaturalParametersSpace, ::Type{Gamma}) = (η) -> begin
(η₁, η₂) = unpack_parameters(Gamma, η)
return SA[digamma(η₁ + one(η₁)) - log(-η₂), - (η₁ + one(η₁)) / η₂]
end

# Mean parametrization

getlogpartition(::MeanParametersSpace, ::Type{Gamma}) = (θ) -> begin
Expand All @@ -114,3 +119,8 @@ getfisherinformation(::MeanParametersSpace, ::Type{Gamma}) = (θ) -> begin
inv(scale) shape/abs2(scale)
]
end

getgradlogpartition(::MeanParametersSpace, ::Type{Gamma}) = (θ) -> begin
(shape, scale) = unpack_parameters(Gamma, θ)
return SA[digamma(shape) - log(scale), - shape / scale]
end
7 changes: 7 additions & 0 deletions src/distributions/gamma_inverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,13 @@ getlogpartition(::NaturalParametersSpace, ::Type{GammaInverse}) = (η) -> begin
return loggamma(-η₁ - one(η₁)) - (-η₁ - one(η₁)) * log(-η₂)
end

getgradlogpartition(::NaturalParametersSpace, ::Type{GammaInverse}) = (η) -> begin
(η₁, η₂) = unpack_parameters(GammaInverse, η)
dη1 = -digamma(-η₁ - one(η₁)) + log(-η₂)
dη2 = - (-η₁ - one(η₁))/η₂
return SA[dη1, dη2]
end

getfisherinformation(::NaturalParametersSpace, ::Type{GammaInverse}) =
(η) -> begin
(η₁, η₂) = unpack_parameters(GammaInverse, η)
Expand Down
10 changes: 10 additions & 0 deletions src/distributions/geometric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ getfisherinformation(::NaturalParametersSpace, ::Type{Geometric}) = (η) -> begi
return SA[exp(η1) / (one(η1) - exp(η1))^2;;]
end

getgradlogpartition(::NaturalParametersSpace, ::Type{Geometric}) = (η) -> begin
(η1,) = unpack_parameters(Geometric, η)
return SA[exp(η1) / (one(η1) - exp(η1));]
end

# Mean parametrization

getlogpartition(::MeanParametersSpace, ::Type{Geometric}) = (θ) -> begin
Expand All @@ -60,3 +65,8 @@ getfisherinformation(::MeanParametersSpace, ::Type{Geometric}) = (θ) -> begin
(p,) = unpack_parameters(Geometric, θ)
return SA[one(p) / (p^2 * (one(p) - p));;]
end

getgradlogpartition(::MeanParametersSpace, ::Type{Geometric}) = (θ) -> begin
(p,) = unpack_parameters(Geometric, θ)
return SA[one(p) / (p^2 - p);]
end
5 changes: 5 additions & 0 deletions src/distributions/laplace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,11 @@ getlogpartition(::NaturalParametersSpace, ::Type{Laplace}, _) = (η) -> begin
return log(-2 / η₁)
end

getgradlogpartition(::NaturalParametersSpace, ::Type{Laplace}, _) = (η) -> begin
(η₁,) = unpack_parameters(Laplace, η)
return SA[-inv(η₁);]
end

getfisherinformation(::NaturalParametersSpace, ::Type{Laplace}, _) = (η) -> begin
(η₁,) = unpack_parameters(Laplace, η)
return SA[inv(η₁^2);;]
Expand Down
14 changes: 14 additions & 0 deletions src/distributions/lognormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,13 @@ getlogpartition(::NaturalParametersSpace, ::Type{LogNormal}) = (η) -> begin
return -(η₁ + 1)^2 / (4η₂) - log(-2η₂) / 2
end

getgradlogpartition(::NaturalParametersSpace, ::Type{LogNormal}) = (η) -> begin
(η₁, η₂) = unpack_parameters(LogNormal, η)
dη1 = -(η₁ + 1)/(2η₂)
dη2 = (η₁ + 1)^2/(4η₂^2) - inv(η₂)/2
return SA[dη1, dη2]
end

getfisherinformation(::NaturalParametersSpace, ::Type{LogNormal}) =
(η) -> begin
(η₁, η₂) = unpack_parameters(LogNormal, η)
Expand All @@ -66,6 +73,13 @@ getlogpartition(::MeanParametersSpace, ::Type{LogNormal}) = (θ) -> begin
return abs2(μ) / (2abs2(σ)) + log(σ)
end

getgradlogpartition(::MeanParametersSpace, ::Type{LogNormal}) = (θ) -> begin
(μ, σ) = unpack_parameters(LogNormal, θ)
dμ = abs(μ) / (abs2(σ))
dσ = -abs2(μ) / (σ^3) + 1/σ
return SA[dμ, dσ]
end

getfisherinformation(::MeanParametersSpace, ::Type{LogNormal}) = (θ) -> begin
(μ, σ) = unpack_parameters(LogNormal, θ)
invσ² = inv(abs2(σ))
Expand Down
16 changes: 16 additions & 0 deletions src/distributions/matrix_dirichlet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,14 @@ getlogpartition(::NaturalParametersSpace, ::Type{MatrixDirichlet}) =
)
end

getgradlogpartition(::NaturalParametersSpace, ::Type{MatrixDirichlet}) =
(η) -> begin
(η1,) = unpack_parameters(MatrixDirichlet, η)
return vmapreduce(
d -> getgradlogpartition(NaturalParametersSpace(), Dirichlet)(convert(Vector, d)),vcat,
eachcol(η1))
end

getfisherinformation(::NaturalParametersSpace, ::Type{MatrixDirichlet}) =
(η) -> begin
(η1,) = unpack_parameters(MatrixDirichlet, η)
Expand All @@ -177,6 +185,14 @@ getlogpartition(::MeanParametersSpace, ::Type{MatrixDirichlet}) =
)
end

getgradlogpartition(::MeanParametersSpace, ::Type{MatrixDirichlet}) =
(θ) -> begin
(α,) = unpack_parameters(MatrixDirichlet, θ)
return vmapreduce(
d -> getgradlogpartition(NaturalParametersSpace(), Dirichlet)(convert(Vector, d)),vcat,
eachcol(α))
end

getfisherinformation(::MeanParametersSpace, ::Type{MatrixDirichlet}) =
(θ) -> begin
(α,) = unpack_parameters(MatrixDirichlet, θ)
Expand Down
10 changes: 10 additions & 0 deletions src/distributions/negative_binomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ getlogpartition(::NaturalParametersSpace, ::Type{NegativeBinomial}, conditioner)
return -conditioner * log(one(η1) - exp(η1))
end

getgradlogpartition(::NaturalParametersSpace,::Type{NegativeBinomial}, conditioner) = (η) -> begin
(η1,) = unpack_parameters(NegativeBinomial, η)
return SA[-conditioner*(-exp(η1)/(one(η1)-exp(η1)));]
end

getfisherinformation(::NaturalParametersSpace, ::Type{NegativeBinomial}, r) = (η) -> begin
(η1,) = unpack_parameters(NegativeBinomial, η)
return SA[r * exp(η1) / (one(η1) - exp(η1))^2;;]
Expand All @@ -118,6 +123,11 @@ getlogpartition(::MeanParametersSpace, ::Type{NegativeBinomial}, conditioner) =
return -conditioner * log(one(p) - p)
end

getgradlogpartition(::MeanParametersSpace,::Type{NegativeBinomial}, conditioner) = (θ) -> begin
(p,) = unpack_parameters(NegativeBinomial, η)
return SA[conditioner*inv(one(p) - p);]
end

getfisherinformation(::MeanParametersSpace, ::Type{NegativeBinomial}, r) = (θ) -> begin
(p,) = unpack_parameters(NegativeBinomial, θ)
return SA[r / (p^2 * (one(p) - p));;]
Expand Down
19 changes: 19 additions & 0 deletions src/distributions/normal_family/normal_family.jl
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,12 @@ getlogpartition(::NaturalParametersSpace, ::Type{NormalMeanVariance}) = (η) ->
return -abs2(η₁) / 4η₂ - log(-2η₂) / 2
end

getgradlogpartition(::NaturalParametersSpace, ::Type{NormalMeanVariance}) =
(η) -> begin
(η₁, η₂) = unpack_parameters(NormalMeanVariance, η)
return SA[-η₁ * inv(η₂*2), abs2(η₁) / ( 4 * abs2(η₂)) - 1 / (2 * η₂)]
end

getfisherinformation(::NaturalParametersSpace, ::Type{NormalMeanVariance}) =
(η) -> begin
(η₁, η₂) = unpack_parameters(NormalMeanVariance, η)
Expand All @@ -591,6 +597,12 @@ getlogpartition(::MeanParametersSpace, ::Type{NormalMeanVariance}) = (θ) -> beg
return μ / 2σ² + log(sqrt(σ²))
end

getgradlogpartition(::MeanParametersSpace, ::Type{NormalMeanVariance}) =
(θ) -> begin
(μ, σ²) = unpack_parameters(NormalMeanVariance, θ)
return SA[μ / σ², - abs2(μ) / (2σ²^2) + 1 / σ²]
end

getfisherinformation(::MeanParametersSpace, ::Type{NormalMeanVariance}) = (θ) -> begin
(_, σ²) = unpack_parameters(NormalMeanVariance, θ)
return SA[inv(σ²) 0; 0 inv(2 * (σ²^2))]
Expand Down Expand Up @@ -678,6 +690,13 @@ getlogpartition(::NaturalParametersSpace, ::Type{MvNormalMeanCovariance}) = (η)
return (dot(η₁, Cinv, η₁) / 2 - (k * log(2) + l)) / 2
end

getgradlogpartition(::NaturalParametersSpace, ::Type{MvNormalMeanCovariance}) =
(η) -> begin
(η₁, η₂) = unpack_parameters(MvNormalMeanCovariance, η)
Cinv, _ = cholinv_logdet(-η₂)
return pack_parameters(MvNormalMeanCovariance, (0.5 * Cinv * η₁, 0.25 * Cinv * η₁ * η₁' * Cinv + 0.5 * Cinv))
end

getfisherinformation(::NaturalParametersSpace, ::Type{MvNormalMeanCovariance}) =
(η) -> begin
(η₁, η₂) = unpack_parameters(MvNormalMeanCovariance, η)
Expand Down
12 changes: 12 additions & 0 deletions src/distributions/normal_gamma.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,18 @@ getlogpartition(::NaturalParametersSpace, ::Type{NormalGamma}) = (η) -> begin
return loggamma(η3half) - log(-2η2) * (1 / 2) - (η3half) * log(-η4 + η1^2 / (4η2))
end

getgradlogpartition(::NaturalParametersSpace,::Type{NormalGamma}) = (η) -> begin
(η1, η2, η3, η4) = unpack_parameters(NormalGamma, η)
η3half = η3 + (1 / 2)
c = (-η4 + η1^2/(4η2))
dη1 = -η3half*((η1/(2η2)) / c)
dη2 = -inv(η2)/2 - η3half*(-η1^2/(4η2^2) / c)
dη3 = digamma(η3half) - log(c)
dη4 = η3half /c

return SA[dη1, dη2, dη3, dη4]
end

getfisherinformation(::NaturalParametersSpace, ::Type{NormalGamma}) =
(η) -> begin
(η1, η2, η3, η4) = unpack_parameters(NormalGamma, η)
Expand Down
10 changes: 10 additions & 0 deletions src/distributions/pareto.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,11 @@ getlogpartition(::NaturalParametersSpace, ::Type{Pareto}, conditioner) = (η) ->
return log(conditioner^(one(η1) + η1) / (-one(η1) - η1))
end

getgradlogpartition(::NaturalParametersSpace, ::Type{Pareto}, conditioner) = (η) -> begin
(η1,) = unpack_parameters(Pareto, η)
return SA[log(conditioner) - one(η1)/(one(η1)+η1);]
end

getfisherinformation(::NaturalParametersSpace, ::Type{Pareto}, _) = (η) -> begin
(η1,) = unpack_parameters(Pareto, η)
return SA[1 / (-1 - η1)^2;;]
Expand All @@ -160,6 +165,11 @@ getlogpartition(::MeanParametersSpace, ::Type{Pareto}, conditioner) = (θ) -> be
return -log(shape) - shape * log(conditioner)
end

getgradlogpartition(::MeanParametersSpace, ::Type{Pareto}, conditioner) = (θ) -> begin
(shape,) = unpack_parameters(Pareto, θ)
return SA[-inv(shape) - log(conditioner);]
end

getfisherinformation(::MeanParametersSpace, ::Type{Pareto}, conditioner) = (θ) -> begin
(α,) = unpack_parameters(Pareto, θ)
### Below fisher information is problematic if α is larger than conditioner as Pareto
Expand Down
5 changes: 5 additions & 0 deletions src/distributions/poisson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ getlogpartition(::NaturalParametersSpace, ::Type{Poisson}) = (η) -> begin
return exp(η1)
end

getgradlogpartition(::NaturalParametersSpace, ::Type{Poisson}) = (η) -> begin
(η1,) = unpack_parameters(Poisson, η)
return SA[exp(η1)]
end

getfisherinformation(::NaturalParametersSpace, ::Type{Poisson}) = (η) -> begin
(η1,) = unpack_parameters(Poisson, η)
SA[exp(η1);;]
Expand Down
10 changes: 10 additions & 0 deletions src/distributions/rayleigh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ getlogpartition(::NaturalParametersSpace, ::Type{Rayleigh}) = (η) -> begin
return -log(-2 * η1)
end

getgradlogpartition(::NaturalParametersSpace, ::Type{Rayleigh}) = (η) -> begin
(η1, ) = unpack_parameters(Rayleigh, η)
return SA[-inv(η1);]
end

getfisherinformation(::NaturalParametersSpace, ::Type{Rayleigh}) = (η) -> begin
(η1,) = unpack_parameters(Rayleigh, η)
SA[inv(η1^2);;]
Expand All @@ -68,6 +73,11 @@ getlogpartition(::MeanParametersSpace, ::Type{Rayleigh}) = (θ) -> begin
return 2 * log(σ)
end

getgradlogpartition(::MeanParametersSpace, ::Type{Rayleigh}) = (θ) -> begin
(σ,) = unpack_parameters(Rayleigh, θ)
return SA[2/σ;]
end

getfisherinformation(::MeanParametersSpace, ::Type{Rayleigh}) = (θ) -> begin
(σ,) = unpack_parameters(Rayleigh, θ)
return SA[4 / σ^2;;]
Expand Down
Loading
Loading