diff --git a/Project.toml b/Project.toml index 1d8d8950..9e8c19c2 100644 --- a/Project.toml +++ b/Project.toml @@ -27,7 +27,7 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" TinyHugeNumbers = "783c9a47-75a3-44ac-a16b-f1ab7b3acf04" [compat] -Aqua = "0.7" +Aqua = "0.8.7" BayesBase = "1.2" Distributions = "0.25" DomainSets = "0.5.2, 0.6, 0.7" diff --git a/src/distributions/poisson.jl b/src/distributions/poisson.jl index 34625973..af2cd398 100644 --- a/src/distributions/poisson.jl +++ b/src/distributions/poisson.jl @@ -52,7 +52,7 @@ unpack_parameters(::Type{Poisson}, packed) = (first(packed),) isbasemeasureconstant(::Type{Poisson}) = NonConstantBaseMeasure() -getbasemeasure(::Type{Poisson}) = (x) -> one(x) / factorial(x) +getbasemeasure(::Type{Poisson}) = (x) -> one(x) / gamma(x + one(x)) getlogbasemeasure(::Type{Poisson}) = (x) -> -loggamma(x + one(x)) getsufficientstatistics(::Type{Poisson}) = (identity,) diff --git a/src/distributions/von_mises_fisher.jl b/src/distributions/von_mises_fisher.jl index c2f05445..7f4f71b7 100644 --- a/src/distributions/von_mises_fisher.jl +++ b/src/distributions/von_mises_fisher.jl @@ -37,7 +37,7 @@ end BayesBase.var(dist::VonMisesFisher) = diag(cov(dist)) BayesBase.std(dist::VonMisesFisher) = sqrt.(var(dist)) -function BayesBase.insupport(ef::ExponentialFamilyDistribution{VonMisesFisher}, x::Vector) +function BayesBase.insupport(ef::ExponentialFamilyDistribution{VonMisesFisher}, x) return length(getnaturalparameters(ef)) == length(x) && Distributions.isunitvec(x) end diff --git a/src/exponential_family.jl b/src/exponential_family.jl index cfdefa1b..f635b826 100644 --- a/src/exponential_family.jl +++ b/src/exponential_family.jl @@ -620,7 +620,7 @@ end function _plogpdf(ef, x) @assert insupport(ef, x) lazy"Point $(x) does not belong to the support of $(ef)" - return _plogpdf(ef, x, logpartition(ef)) + return _plogpdf(ef, x, logpartition(ef), logbasemeasure(ef,x)) end _scalarproduct(::Type{T}, η, statistics) where {T} = _scalarproduct(variate_form(T), T, η, statistics) @@ -628,13 +628,13 @@ _scalarproduct(::Type{Univariate}, η, statistics) = dot(η, flatten_parameters( _scalarproduct(::Type{Univariate}, ::Type{T}, η, statistics) where {T} = dot(η, flatten_parameters(T, statistics)) _scalarproduct(_, ::Type{T}, η, statistics) where {T} = dot(η, pack_parameters(T, statistics)) -function _plogpdf(ef::ExponentialFamilyDistribution{T}, x, logpartition) where {T} + +function _plogpdf(ef::ExponentialFamilyDistribution{T}, x, logpartition, logbasemeasure) where {T} # TODO: Think of what to do with this assert @assert insupport(ef, x) lazy"Point $(x) does not belong to the support of $(ef)" η = getnaturalparameters(ef) _statistics = sufficientstatistics(ef, x) - _basemeasure = basemeasure(ef, x) - return log(_basemeasure) + _scalarproduct(T, η, _statistics) - logpartition + return logbasemeasure + _scalarproduct(T, η, _statistics) - logpartition end """ @@ -683,7 +683,7 @@ check_logpdf(::Type{Matrixvariate}, ::Type{<:AbstractMatrix}, ::Type{<:Number}, function _vlogpdf(ef, container) _logpartition = logpartition(ef) - return map(x -> _plogpdf(ef, x, _logpartition), container) + return map(x -> _plogpdf(ef, x, _logpartition, logbasemeasure(ef,x)), container) end check_logpdf(::Type{Univariate}, ::Type{<:AbstractVector}, ::Type{<:Number}, ef, container) = (MapBasedLogpdfCall(), container) diff --git a/test/distributions/distributions_setuptests.jl b/test/distributions/distributions_setuptests.jl index fa97d2b9..514cc099 100644 --- a/test/distributions/distributions_setuptests.jl +++ b/test/distributions/distributions_setuptests.jl @@ -1,5 +1,5 @@ using ExponentialFamily, BayesBase, FastCholesky, Distributions, LinearAlgebra, TinyHugeNumbers -using Test, ForwardDiff, Random, StatsFuns, StableRNGs, FillArrays, JET +using Test, ForwardDiff, Random, StatsFuns, StableRNGs, FillArrays, JET, SpecialFunctions import BayesBase: compute_logscale @@ -67,6 +67,7 @@ function test_exponentialfamily_interface(distribution; test_fisherinformation_properties = true, test_fisherinformation_against_hessian = true, test_fisherinformation_against_jacobian = true, + test_plogpdf_interface = true, option_assume_no_allocations = false ) T = ExponentialFamily.exponential_family_typetag(distribution) @@ -87,10 +88,20 @@ function test_exponentialfamily_interface(distribution; test_fisherinformation_properties && run_test_fisherinformation_properties(distribution) test_fisherinformation_against_hessian && run_test_fisherinformation_against_hessian(distribution; assume_no_allocations = option_assume_no_allocations) test_fisherinformation_against_jacobian && run_test_fisherinformation_against_jacobian(distribution; assume_no_allocations = option_assume_no_allocations) - + test_plogpdf_interface && run_test_plogpdf_interface(distribution) return ef end +function run_test_plogpdf_interface(distribution) + ef = convert(ExponentialFamily.ExponentialFamilyDistribution, distribution) + η = getnaturalparameters(ef) + samples = rand(StableRNG(42), distribution, 10) + _, _samples = ExponentialFamily.check_logpdf(variate_form(typeof(ef)), typeof(samples), eltype(samples), ef, samples) + ss_vectors = map(s -> ExponentialFamily.pack_parameters(ExponentialFamily.sufficientstatistics(ef, s)), _samples) + unnormalized_logpdfs = map(v -> dot(v, η), ss_vectors) + @test all(unnormalized_logpdfs ≈ map(x -> ExponentialFamily._plogpdf(ef, x, 0, 0), _samples)) +end + function run_test_parameters_conversion(distribution) T = ExponentialFamily.exponential_family_typetag(distribution) diff --git a/test/distributions/mv_normal_wishart_tests.jl b/test/distributions/mv_normal_wishart_tests.jl index 5fa03c57..f20a3b24 100644 --- a/test/distributions/mv_normal_wishart_tests.jl +++ b/test/distributions/mv_normal_wishart_tests.jl @@ -23,7 +23,8 @@ end test_basic_functions = false, test_fisherinformation_against_hessian = false, test_fisherinformation_against_jacobian = false, - test_gradlogpartition_properties = false + test_gradlogpartition_properties = false, + test_plogpdf_interface = false ) run_test_basic_functions(d; assume_no_allocations = false, test_samples_logpdf = false) diff --git a/test/distributions/poisson_tests.jl b/test/distributions/poisson_tests.jl index 78c06dc7..29b8fb6f 100644 --- a/test/distributions/poisson_tests.jl +++ b/test/distributions/poisson_tests.jl @@ -5,7 +5,7 @@ @testitem "Poisson: ExponentialFamilyDistribution" begin include("distributions_setuptests.jl") - @testset for i in 2:4 + @testset for i in 2:7 @testset let d = Poisson(2 * (i + 1)) ef = test_exponentialfamily_interface(d; option_assume_no_allocations = true) η1 = first(getnaturalparameters(ef)) @@ -40,7 +40,7 @@ end prod_dist = prod(PreserveTypeProd(ExponentialFamilyDistribution), left, right) sample_points = collect(1:5) for x in sample_points - @test basemeasure(prod_dist, x) == (1 / factorial(x)^2) + @test basemeasure(prod_dist, x) == (1 / gamma(x + one(x))^2) @test sufficientstatistics(prod_dist, x) == (x,) end sample_points = [-5, -2, 0, 2, 5] diff --git a/test/exponential_family_setuptests.jl b/test/exponential_family_setuptests.jl index 3955a818..ebf32e08 100644 --- a/test/exponential_family_setuptests.jl +++ b/test/exponential_family_setuptests.jl @@ -1,8 +1,8 @@ using ExponentialFamily, BayesBase, Distributions, Test, StatsFuns, BenchmarkTools, Random, FillArrays import Distributions: RealInterval, ContinuousUnivariateDistribution, Univariate -import ExponentialFamily: basemeasure, sufficientstatistics, logpartition, insupport, ConstantBaseMeasure -import ExponentialFamily: getnaturalparameters, getbasemeasure, getsufficientstatistics, getlogpartition, getsupport +import ExponentialFamily: basemeasure, logbasemeasure, sufficientstatistics, logpartition, insupport, ConstantBaseMeasure +import ExponentialFamily: getnaturalparameters, getbasemeasure, getlogbasemeasure, getsufficientstatistics, getlogpartition, getsupport import ExponentialFamily: ExponentialFamilyDistributionAttributes, NaturalParametersSpace # import ExponentialFamily: @@ -28,6 +28,7 @@ end ExponentialFamily.isproper(::NaturalParametersSpace, ::Type{ArbitraryDistributionFromExponentialFamily}, η, conditioner) = isnothing(conditioner) ExponentialFamily.isbasemeasureconstant(::Type{ArbitraryDistributionFromExponentialFamily}) = ConstantBaseMeasure() ExponentialFamily.getbasemeasure(::Type{ArbitraryDistributionFromExponentialFamily}) = (x) -> oneunit(x) +ExponentialFamily.getlogbasemeasure(::Type{ArbitraryDistributionFromExponentialFamily}) = (x) -> zero(x) ExponentialFamily.getsufficientstatistics(::Type{ArbitraryDistributionFromExponentialFamily}) = ((x) -> x, (x) -> log(x)) ExponentialFamily.getlogpartition(::NaturalParametersSpace, ::Type{ArbitraryDistributionFromExponentialFamily}) = (η) -> 1 / sum(η) @@ -52,6 +53,7 @@ end ExponentialFamily.isproper(::NaturalParametersSpace, ::Type{ArbitraryConditionedDistributionFromExponentialFamily}, η, conditioner) = isinteger(conditioner) ExponentialFamily.isbasemeasureconstant(::Type{ArbitraryConditionedDistributionFromExponentialFamily}) = NonConstantBaseMeasure() ExponentialFamily.getbasemeasure(::Type{ArbitraryConditionedDistributionFromExponentialFamily}, conditioner) = (x) -> x^conditioner +ExponentialFamily.getlogbasemeasure(::Type{ArbitraryConditionedDistributionFromExponentialFamily}, conditioner) = (x) -> conditioner*log(x) ExponentialFamily.getsufficientstatistics(::Type{ArbitraryConditionedDistributionFromExponentialFamily}, conditioner) = ((x) -> log(x - conditioner),) ExponentialFamily.getlogpartition(::NaturalParametersSpace, ::Type{ArbitraryConditionedDistributionFromExponentialFamily}, conditioner) = diff --git a/test/exponential_family_tests.jl b/test/exponential_family_tests.jl index 17c6f32c..b70f77c3 100644 --- a/test/exponential_family_tests.jl +++ b/test/exponential_family_tests.jl @@ -128,6 +128,10 @@ end @test @inferred(getbasemeasure(member)(2.0)) ≈ 1.0 @test @inferred(getbasemeasure(member)(4.0)) ≈ 1.0 + @test @inferred(logbasemeasure(member, 2.0)) ≈ log(1.0) + @test @inferred(getlogbasemeasure(member)(2.0)) ≈ log(1.0) + @test @inferred(getlogbasemeasure(member)(4.0)) ≈ log(1.0) + @test all(@inferred(sufficientstatistics(member, 2.0)) .≈ (2.0, log(2.0))) @test all(@inferred(map(f -> f(2.0), getsufficientstatistics(member))) .≈ (2.0, log(2.0))) @test all(@inferred(map(f -> f(4.0), getsufficientstatistics(member))) .≈ (4.0, log(4.0))) @@ -206,6 +210,10 @@ end @test @inferred(getbasemeasure(member)(2.0)) ≈ 2.0^-2 @test @inferred(getbasemeasure(member)(4.0)) ≈ 4.0^-2 + @test @inferred(logbasemeasure(member, 2.0)) ≈ -2*log(2.0) + @test @inferred(getlogbasemeasure(member)(2.0)) ≈ -2*log(2.0) + @test @inferred(getlogbasemeasure(member)(4.0)) ≈ -2*log(4.0) + @test all(@inferred(sufficientstatistics(member, 2.0)) .≈ (log(2.0 + 2),)) @test all(@inferred(map(f -> f(2.0), getsufficientstatistics(member))) .≈ (log(2.0 + 2),)) @test all(@inferred(map(f -> f(4.0), getsufficientstatistics(member))) .≈ (log(4.0 + 2),)) diff --git a/test/runtests.jl b/test/runtests.jl index 8776bbc9..db294694 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,8 +1,8 @@ using Aqua, CpuId, ReTestItems, ExponentialFamily # `ambiguities = false` - there are quite some ambiguities, but these should be normal and should not be encountered under normal circumstances -# `piracy = false` - we extend/add some of the methods to the objects defined in the Distributions.jl -Aqua.test_all(ExponentialFamily, ambiguities = false, deps_compat = (; check_extras = false, check_weakdeps = true), piracy = false) +# `piracies = false` - we extend/add some of the methods to the objects defined in the Distributions.jl +Aqua.test_all(ExponentialFamily, ambiguities = false, deps_compat = (; check_extras = false, check_weakdeps = true), piracies = false) nthreads = max(cputhreads(), 1) ncores = max(cpucores(), 1)