Skip to content

Commit

Permalink
Merge pull request #175 from biaslab/kurtosis_skewness
Browse files Browse the repository at this point in the history
Add: kurtosis and skewness, Fix: piracy to piracies
  • Loading branch information
bvdmitri authored Jan 9, 2024
2 parents ae78f1e + 3b49308 commit 43e0691
Show file tree
Hide file tree
Showing 10 changed files with 59 additions and 31 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ExponentialFamily"
uuid = "62312e5e-252a-4322-ace9-a5f4bf9b357b"
authors = ["Ismail Senoz <i.senoz@tue.nl>", "Dmitry Bagaev <d.v.bagaev@tue.nl>"]
version = "1.2.2"
version = "1.3.0"

[deps]
BayesBase = "b4ee3484-f114-42fe-b91c-797d54a0c67e"
Expand All @@ -28,7 +28,7 @@ TinyHugeNumbers = "783c9a47-75a3-44ac-a16b-f1ab7b3acf04"

[compat]
Aqua = "0.7"
BayesBase = "1.1"
BayesBase = "1.2"
Distributions = "0.25"
DomainSets = "0.5.2, 0.6, 0.7"
FastCholesky = "1.0"
Expand Down
16 changes: 9 additions & 7 deletions src/distributions/gamma_family/gamma_shape_rate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,15 @@ GammaShapeRate() = GammaShapeRate(1.0, 1.0)

Distributions.@distr_support GammaShapeRate 0 Inf

BayesBase.support(dist::GammaShapeRate) = Distributions.RealInterval(minimum(dist), maximum(dist))
BayesBase.shape(dist::GammaShapeRate) = dist.a
BayesBase.rate(dist::GammaShapeRate) = dist.b
BayesBase.scale(dist::GammaShapeRate) = inv(dist.b)
BayesBase.mean(dist::GammaShapeRate) = shape(dist) / rate(dist)
BayesBase.var(dist::GammaShapeRate) = shape(dist) / abs2(rate(dist))
BayesBase.params(dist::GammaShapeRate) = (shape(dist), rate(dist))
BayesBase.support(dist::GammaShapeRate) = Distributions.RealInterval(minimum(dist), maximum(dist))
BayesBase.shape(dist::GammaShapeRate) = dist.a
BayesBase.rate(dist::GammaShapeRate) = dist.b
BayesBase.scale(dist::GammaShapeRate) = inv(dist.b)
BayesBase.mean(dist::GammaShapeRate) = shape(dist) / rate(dist)
BayesBase.var(dist::GammaShapeRate) = shape(dist) / abs2(rate(dist))
BayesBase.params(dist::GammaShapeRate) = (shape(dist), rate(dist))
BayesBase.kurtosis(dist::GammaShapeRate) = kurtosis(convert(Gamma, dist))
BayesBase.skewness(dist::GammaShapeRate) = skewness(convert(Gamma, dist))

BayesBase.mode(d::GammaShapeRate) =
shape(d) >= 1 ? mode(Gamma(shape(d), scale(d))) : throw(error("Gamma has no mode when shape < 1"))
Expand Down
18 changes: 10 additions & 8 deletions src/distributions/normal_family/normal_mean_precision.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,17 @@ BayesBase.support(dist::NormalMeanPrecision) = Distributions.RealInterval(minimu

BayesBase.weightedmean(dist::NormalMeanPrecision) = precision(dist) * mean(dist)

BayesBase.mean(dist::NormalMeanPrecision) = dist.μ
BayesBase.median(dist::NormalMeanPrecision) = mean(dist)
BayesBase.mode(dist::NormalMeanPrecision) = mean(dist)
BayesBase.var(dist::NormalMeanPrecision) = inv(dist.w)
BayesBase.std(dist::NormalMeanPrecision) = sqrt(var(dist))
BayesBase.cov(dist::NormalMeanPrecision) = var(dist)
BayesBase.invcov(dist::NormalMeanPrecision) = dist.w
BayesBase.mean(dist::NormalMeanPrecision) = dist.μ
BayesBase.median(dist::NormalMeanPrecision) = mean(dist)
BayesBase.mode(dist::NormalMeanPrecision) = mean(dist)
BayesBase.var(dist::NormalMeanPrecision) = inv(dist.w)
BayesBase.std(dist::NormalMeanPrecision) = sqrt(var(dist))
BayesBase.cov(dist::NormalMeanPrecision) = var(dist)
BayesBase.invcov(dist::NormalMeanPrecision) = dist.w
BayesBase.entropy(dist::NormalMeanPrecision) = (1 + log2π - log(precision(dist))) / 2
BayesBase.params(dist::NormalMeanPrecision) = (mean(dist), precision(dist))
BayesBase.params(dist::NormalMeanPrecision) = (mean(dist), precision(dist))
BayesBase.kurtosis(dist::NormalMeanPrecision) = kurtosis(convert(Normal, dist))
BayesBase.skewness(dist::NormalMeanPrecision) = skewness(convert(Normal, dist))

BayesBase.pdf(dist::NormalMeanPrecision, x::Real) = (invsqrt2π * exp(-abs2(x - mean(dist)) * precision(dist) / 2)) * sqrt(precision(dist))
BayesBase.logpdf(dist::NormalMeanPrecision, x::Real) = -(log2π - log(precision(dist)) + abs2(x - mean(dist)) * precision(dist)) / 2
Expand Down
21 changes: 12 additions & 9 deletions src/distributions/normal_family/normal_mean_variance.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,18 @@ function BayesBase.weightedmean_invcov(dist::NormalMeanVariance)
return (xi, w)
end

BayesBase.mean(dist::NormalMeanVariance) = dist.μ
BayesBase.median(dist::NormalMeanVariance) = mean(dist)
BayesBase.mode(dist::NormalMeanVariance) = mean(dist)
BayesBase.var(dist::NormalMeanVariance) = dist.v
BayesBase.std(dist::NormalMeanVariance) = sqrt(var(dist))
BayesBase.cov(dist::NormalMeanVariance) = var(dist)
BayesBase.invcov(dist::NormalMeanVariance) = inv(cov(dist))
BayesBase.entropy(dist::NormalMeanVariance) = (1 + log2π + log(var(dist))) / 2
BayesBase.params(dist::NormalMeanVariance) = (dist.μ, dist.v)
BayesBase.mean(dist::NormalMeanVariance) = dist.μ
BayesBase.median(dist::NormalMeanVariance) = mean(dist)
BayesBase.mode(dist::NormalMeanVariance) = mean(dist)
BayesBase.var(dist::NormalMeanVariance) = dist.v
BayesBase.std(dist::NormalMeanVariance) = sqrt(var(dist))
BayesBase.cov(dist::NormalMeanVariance) = var(dist)
BayesBase.invcov(dist::NormalMeanVariance) = inv(cov(dist))
BayesBase.entropy(dist::NormalMeanVariance) = (1 + log2π + log(var(dist))) / 2
BayesBase.params(dist::NormalMeanVariance) = (dist.μ, dist.v)
BayesBase.kurtosis(dist::NormalMeanVariance) = kurtosis(convert(Normal, dist))
BayesBase.skewness(dist::NormalMeanVariance) = skewness(convert(Normal, dist))

BayesBase.pdf(dist::NormalMeanVariance, x::Real) = (invsqrt2π * exp(-abs2(x - mean(dist)) / 2cov(dist))) / std(dist)
BayesBase.logpdf(dist::NormalMeanVariance, x::Real) = -(log2π + log(var(dist)) + abs2(x - mean(dist)) / var(dist)) / 2

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ BayesBase.cov(dist::NormalWeightedMeanPrecision) = var(dist)
BayesBase.invcov(dist::NormalWeightedMeanPrecision) = dist.w
BayesBase.entropy(dist::NormalWeightedMeanPrecision) = (1 + log2π - log(precision(dist))) / 2
BayesBase.params(dist::NormalWeightedMeanPrecision) = (weightedmean(dist), precision(dist))
BayesBase.kurtosis(dist::NormalWeightedMeanPrecision) = kurtosis(convert(Normal, dist))
BayesBase.skewness(dist::NormalWeightedMeanPrecision) = skewness(convert(Normal, dist))
BayesBase.pdf(dist::NormalWeightedMeanPrecision, x::Real) = (invsqrt2π * exp(-abs2(x - mean(dist)) * precision(dist) / 2)) * sqrt(precision(dist))
BayesBase.logpdf(dist::NormalWeightedMeanPrecision, x::Real) = -(log2π - log(precision(dist)) + abs2(x - mean(dist)) * precision(dist)) / 2

Expand Down
2 changes: 2 additions & 0 deletions src/exponential_family.jl
Original file line number Diff line number Diff line change
Expand Up @@ -830,6 +830,8 @@ BayesBase.mean(ef::ExponentialFamilyDistribution{T}) where {T <: Distribution} =
BayesBase.var(ef::ExponentialFamilyDistribution{T}) where {T <: Distribution} = var(convert(T, ef))
BayesBase.std(ef::ExponentialFamilyDistribution{T}) where {T <: Distribution} = std(convert(T, ef))
BayesBase.cov(ef::ExponentialFamilyDistribution{T}) where {T <: Distribution} = cov(convert(T, ef))
BayesBase.skewness(ef::ExponentialFamilyDistribution{T}) where {T <: Distribution} = skewness(convert(T, ef))
BayesBase.kurtosis(ef::ExponentialFamilyDistribution{T}) where {T <: Distribution} = kurtosis(convert(T, ef))

BayesBase.rand(ef::ExponentialFamilyDistribution, args...) = rand(Random.default_rng(), ef, args...)
BayesBase.rand!(ef::ExponentialFamilyDistribution, args...) = rand!(Random.default_rng(), ef, args...)
Expand Down
19 changes: 18 additions & 1 deletion test/distributions/distributions_setuptests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ function test_exponentialfamily_interface(distribution;
test_fisherinformation_properties = true,
test_fisherinformation_against_hessian = true,
test_fisherinformation_against_jacobian = true,
option_assume_no_allocations = false,
option_assume_no_allocations = false
)
T = ExponentialFamily.exponential_family_typetag(distribution)

Expand Down Expand Up @@ -203,6 +203,17 @@ function run_test_basic_functions(distribution; nsamples = 10, test_gradients =
# ! do not use fixed RNG
samples = [rand(distribution) for _ in 1:nsamples]

# Not all methods are defined for all objects in Distributions.jl
# For this methods we first test if the method is defined for the distribution
# And only then we test the method for the exponential family form
potentially_missing_methods = (
cov,
skewness,
kurtosis
)

argument_type = Tuple{typeof(distribution)}

for x in samples
# We believe in the implementation in the `Distributions.jl`
@test @inferred(logpdf(ef, x)) logpdf(distribution, x)
Expand All @@ -214,6 +225,12 @@ function run_test_basic_functions(distribution; nsamples = 10, test_gradients =
@test all(rand(StableRNG(42), ef, 10) .≈ rand(StableRNG(42), distribution, 10))
@test all(rand!(StableRNG(42), ef, [deepcopy(x) for _ in 1:10]) .≈ rand!(StableRNG(42), distribution, [deepcopy(x) for _ in 1:10]))

for method in potentially_missing_methods
if hasmethod(method, argument_type)
@test @inferred(method(ef)) method(distribution)
end
end

@test @inferred(isbasemeasureconstant(ef)) === isbasemeasureconstant(T)
@test @inferred(basemeasure(ef, x)) == getbasemeasure(T, conditioner)(x)
@test all(@inferred(sufficientstatistics(ef, x)) .== map(f -> f(x), getsufficientstatistics(T, conditioner)))
Expand Down
2 changes: 1 addition & 1 deletion test/distributions/gamma_inverse_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ end
@testitem "GammaInverse: ExponentialFamilyDistribution" begin
include("distributions_setuptests.jl")

for α in 10rand(4), θ in 10rand(4)
for α in (10rand(4) .+ 4.0), θ in 10rand(4)
@testset let d = InverseGamma(α, θ)
ef = test_exponentialfamily_interface(d; option_assume_no_allocations = true)

Expand Down
4 changes: 2 additions & 2 deletions test/distributions/mv_normal_wishart_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ end
@testitem "MvNormalWishart: ExponentialFamilyDistribution" begin
include("distributions_setuptests.jl")

for dim in (3), invS in rand(Wishart(10, Array(Eye(dim))), 4)
for dim in (3,), invS in rand(Wishart(10, Array(Eye(dim))), 4)
ν = dim + 2
@testset let (d = MvNormalWishart(rand(dim), invS, rand(), ν))
ef = test_exponentialfamily_interface(
Expand All @@ -25,7 +25,7 @@ end
test_fisherinformation_against_jacobian = false
)

run_test_basic_functions(ef; assume_no_allocations = false, test_samples_logpdf = false)
run_test_basic_functions(d; assume_no_allocations = false, test_samples_logpdf = false)
end
end
end
Expand Down
2 changes: 1 addition & 1 deletion test/distributions/pareto_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
@testitem "Pareto: ExponentialFamilyDistribution" begin
include("distributions_setuptests.jl")

for shape in (1.0, 2.0, 3.0), scale in (0.25, 0.5, 2.0)
for shape in (5.0, 6.0, 7.0), scale in (0.25, 0.5, 2.0)
@testset let d = Pareto(shape, scale)
ef = test_exponentialfamily_interface(d; option_assume_no_allocations = false)
η1 = -shape - 1
Expand Down

0 comments on commit 43e0691

Please sign in to comment.