Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change _plogpdf function to accept four arguments #199

Merged
merged 14 commits into from
Jul 5, 2024
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
TinyHugeNumbers = "783c9a47-75a3-44ac-a16b-f1ab7b3acf04"

[compat]
Aqua = "0.7"
Aqua = "0.8.7"
BayesBase = "1.2"
Distributions = "0.25"
DomainSets = "0.5.2, 0.6, 0.7"
Expand Down
2 changes: 1 addition & 1 deletion src/distributions/poisson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ unpack_parameters(::Type{Poisson}, packed) = (first(packed),)

isbasemeasureconstant(::Type{Poisson}) = NonConstantBaseMeasure()

getbasemeasure(::Type{Poisson}) = (x) -> one(x) / factorial(x)
getbasemeasure(::Type{Poisson}) = (x) -> one(x) / gamma(x + one(x))
getlogbasemeasure(::Type{Poisson}) = (x) -> -loggamma(x + one(x))
getsufficientstatistics(::Type{Poisson}) = (identity,)

Expand Down
2 changes: 1 addition & 1 deletion src/distributions/von_mises_fisher.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ end
BayesBase.var(dist::VonMisesFisher) = diag(cov(dist))
BayesBase.std(dist::VonMisesFisher) = sqrt.(var(dist))

function BayesBase.insupport(ef::ExponentialFamilyDistribution{VonMisesFisher}, x::Vector)
function BayesBase.insupport(ef::ExponentialFamilyDistribution{VonMisesFisher}, x)
return length(getnaturalparameters(ef)) == length(x) && Distributions.isunitvec(x)
end

Expand Down
10 changes: 5 additions & 5 deletions src/exponential_family.jl
Original file line number Diff line number Diff line change
Expand Up @@ -620,21 +620,21 @@ end

function _plogpdf(ef, x)
@assert insupport(ef, x) lazy"Point $(x) does not belong to the support of $(ef)"
return _plogpdf(ef, x, logpartition(ef))
return _plogpdf(ef, x, logpartition(ef), logbasemeasure(ef,x))
end

_scalarproduct(::Type{T}, η, statistics) where {T} = _scalarproduct(variate_form(T), T, η, statistics)
_scalarproduct(::Type{Univariate}, η, statistics) = dot(η, flatten_parameters(statistics))
_scalarproduct(::Type{Univariate}, ::Type{T}, η, statistics) where {T} = dot(η, flatten_parameters(T, statistics))
_scalarproduct(_, ::Type{T}, η, statistics) where {T} = dot(η, pack_parameters(T, statistics))

function _plogpdf(ef::ExponentialFamilyDistribution{T}, x, logpartition) where {T}

function _plogpdf(ef::ExponentialFamilyDistribution{T}, x, logpartition, logbasemeasure) where {T}
# TODO: Think of what to do with this assert
@assert insupport(ef, x) lazy"Point $(x) does not belong to the support of $(ef)"
η = getnaturalparameters(ef)
_statistics = sufficientstatistics(ef, x)
_basemeasure = basemeasure(ef, x)
return log(_basemeasure) + _scalarproduct(T, η, _statistics) - logpartition
return logbasemeasure + _scalarproduct(T, η, _statistics) - logpartition
end

"""
Expand Down Expand Up @@ -683,7 +683,7 @@ check_logpdf(::Type{Matrixvariate}, ::Type{<:AbstractMatrix}, ::Type{<:Number},

function _vlogpdf(ef, container)
_logpartition = logpartition(ef)
return map(x -> _plogpdf(ef, x, _logpartition), container)
return map(x -> _plogpdf(ef, x, _logpartition, logbasemeasure(ef,x)), container)
end

check_logpdf(::Type{Univariate}, ::Type{<:AbstractVector}, ::Type{<:Number}, ef, container) = (MapBasedLogpdfCall(), container)
Expand Down
15 changes: 13 additions & 2 deletions test/distributions/distributions_setuptests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using ExponentialFamily, BayesBase, FastCholesky, Distributions, LinearAlgebra, TinyHugeNumbers
using Test, ForwardDiff, Random, StatsFuns, StableRNGs, FillArrays, JET
using Test, ForwardDiff, Random, StatsFuns, StableRNGs, FillArrays, JET, SpecialFunctions

import BayesBase: compute_logscale

Expand Down Expand Up @@ -67,6 +67,7 @@ function test_exponentialfamily_interface(distribution;
test_fisherinformation_properties = true,
test_fisherinformation_against_hessian = true,
test_fisherinformation_against_jacobian = true,
test_plogpdf_interface = true,
option_assume_no_allocations = false
)
T = ExponentialFamily.exponential_family_typetag(distribution)
Expand All @@ -87,10 +88,20 @@ function test_exponentialfamily_interface(distribution;
test_fisherinformation_properties && run_test_fisherinformation_properties(distribution)
test_fisherinformation_against_hessian && run_test_fisherinformation_against_hessian(distribution; assume_no_allocations = option_assume_no_allocations)
test_fisherinformation_against_jacobian && run_test_fisherinformation_against_jacobian(distribution; assume_no_allocations = option_assume_no_allocations)

test_plogpdf_interface && run_test_plogpdf_interface(distribution)
return ef
end

function run_test_plogpdf_interface(distribution)
ef = convert(ExponentialFamily.ExponentialFamilyDistribution, distribution)
η = getnaturalparameters(ef)
samples = rand(StableRNG(42), distribution, 10)
_, _samples = ExponentialFamily.check_logpdf(variate_form(typeof(ef)), typeof(samples), eltype(samples), ef, samples)
ss_vectors = map(s -> ExponentialFamily.pack_parameters(ExponentialFamily.sufficientstatistics(ef, s)), _samples)
unnormalized_logpdfs = map(v -> dot(v, η), ss_vectors)
@test all(unnormalized_logpdfs ≈ map(x -> ExponentialFamily._plogpdf(ef, x, 0, 0), _samples))
end

function run_test_parameters_conversion(distribution)
T = ExponentialFamily.exponential_family_typetag(distribution)

Expand Down
3 changes: 2 additions & 1 deletion test/distributions/mv_normal_wishart_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ end
test_basic_functions = false,
test_fisherinformation_against_hessian = false,
test_fisherinformation_against_jacobian = false,
test_gradlogpartition_properties = false
test_gradlogpartition_properties = false,
test_plogpdf_interface = false
)

run_test_basic_functions(d; assume_no_allocations = false, test_samples_logpdf = false)
Expand Down
4 changes: 2 additions & 2 deletions test/distributions/poisson_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
@testitem "Poisson: ExponentialFamilyDistribution" begin
include("distributions_setuptests.jl")

@testset for i in 2:4
@testset for i in 2:7
@testset let d = Poisson(2 * (i + 1))
ef = test_exponentialfamily_interface(d; option_assume_no_allocations = true)
η1 = first(getnaturalparameters(ef))
Expand Down Expand Up @@ -40,7 +40,7 @@ end
prod_dist = prod(PreserveTypeProd(ExponentialFamilyDistribution), left, right)
sample_points = collect(1:5)
for x in sample_points
@test basemeasure(prod_dist, x) == (1 / factorial(x)^2)
@test basemeasure(prod_dist, x) == (1 / gamma(x + one(x))^2)
@test sufficientstatistics(prod_dist, x) == (x,)
end
sample_points = [-5, -2, 0, 2, 5]
Expand Down
6 changes: 4 additions & 2 deletions test/exponential_family_setuptests.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
using ExponentialFamily, BayesBase, Distributions, Test, StatsFuns, BenchmarkTools, Random, FillArrays

import Distributions: RealInterval, ContinuousUnivariateDistribution, Univariate
import ExponentialFamily: basemeasure, sufficientstatistics, logpartition, insupport, ConstantBaseMeasure
import ExponentialFamily: getnaturalparameters, getbasemeasure, getsufficientstatistics, getlogpartition, getsupport
import ExponentialFamily: basemeasure, logbasemeasure, sufficientstatistics, logpartition, insupport, ConstantBaseMeasure
import ExponentialFamily: getnaturalparameters, getbasemeasure, getlogbasemeasure, getsufficientstatistics, getlogpartition, getsupport
import ExponentialFamily: ExponentialFamilyDistributionAttributes, NaturalParametersSpace

# import ExponentialFamily:
Expand All @@ -28,6 +28,7 @@ end
ExponentialFamily.isproper(::NaturalParametersSpace, ::Type{ArbitraryDistributionFromExponentialFamily}, η, conditioner) = isnothing(conditioner)
ExponentialFamily.isbasemeasureconstant(::Type{ArbitraryDistributionFromExponentialFamily}) = ConstantBaseMeasure()
ExponentialFamily.getbasemeasure(::Type{ArbitraryDistributionFromExponentialFamily}) = (x) -> oneunit(x)
ExponentialFamily.getlogbasemeasure(::Type{ArbitraryDistributionFromExponentialFamily}) = (x) -> zero(x)
ExponentialFamily.getsufficientstatistics(::Type{ArbitraryDistributionFromExponentialFamily}) =
((x) -> x, (x) -> log(x))
ExponentialFamily.getlogpartition(::NaturalParametersSpace, ::Type{ArbitraryDistributionFromExponentialFamily}) = (η) -> 1 / sum(η)
Expand All @@ -52,6 +53,7 @@ end
ExponentialFamily.isproper(::NaturalParametersSpace, ::Type{ArbitraryConditionedDistributionFromExponentialFamily}, η, conditioner) = isinteger(conditioner)
ExponentialFamily.isbasemeasureconstant(::Type{ArbitraryConditionedDistributionFromExponentialFamily}) = NonConstantBaseMeasure()
ExponentialFamily.getbasemeasure(::Type{ArbitraryConditionedDistributionFromExponentialFamily}, conditioner) = (x) -> x^conditioner
ExponentialFamily.getlogbasemeasure(::Type{ArbitraryConditionedDistributionFromExponentialFamily}, conditioner) = (x) -> conditioner*log(x)
ExponentialFamily.getsufficientstatistics(::Type{ArbitraryConditionedDistributionFromExponentialFamily}, conditioner) =
((x) -> log(x - conditioner),)
ExponentialFamily.getlogpartition(::NaturalParametersSpace, ::Type{ArbitraryConditionedDistributionFromExponentialFamily}, conditioner) =
Expand Down
8 changes: 8 additions & 0 deletions test/exponential_family_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,10 @@ end
@test @inferred(getbasemeasure(member)(2.0)) ≈ 1.0
@test @inferred(getbasemeasure(member)(4.0)) ≈ 1.0

@test @inferred(logbasemeasure(member, 2.0)) ≈ log(1.0)
@test @inferred(getlogbasemeasure(member)(2.0)) ≈ log(1.0)
@test @inferred(getlogbasemeasure(member)(4.0)) ≈ log(1.0)

@test all(@inferred(sufficientstatistics(member, 2.0)) .≈ (2.0, log(2.0)))
@test all(@inferred(map(f -> f(2.0), getsufficientstatistics(member))) .≈ (2.0, log(2.0)))
@test all(@inferred(map(f -> f(4.0), getsufficientstatistics(member))) .≈ (4.0, log(4.0)))
Expand Down Expand Up @@ -206,6 +210,10 @@ end
@test @inferred(getbasemeasure(member)(2.0)) ≈ 2.0^-2
@test @inferred(getbasemeasure(member)(4.0)) ≈ 4.0^-2

@test @inferred(logbasemeasure(member, 2.0)) ≈ -2*log(2.0)
@test @inferred(getlogbasemeasure(member)(2.0)) ≈ -2*log(2.0)
@test @inferred(getlogbasemeasure(member)(4.0)) ≈ -2*log(4.0)

@test all(@inferred(sufficientstatistics(member, 2.0)) .≈ (log(2.0 + 2),))
@test all(@inferred(map(f -> f(2.0), getsufficientstatistics(member))) .≈ (log(2.0 + 2),))
@test all(@inferred(map(f -> f(4.0), getsufficientstatistics(member))) .≈ (log(4.0 + 2),))
Expand Down
4 changes: 2 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
using Aqua, CpuId, ReTestItems, ExponentialFamily

# `ambiguities = false` - there are quite some ambiguities, but these should be normal and should not be encountered under normal circumstances
# `piracy = false` - we extend/add some of the methods to the objects defined in the Distributions.jl
Aqua.test_all(ExponentialFamily, ambiguities = false, deps_compat = (; check_extras = false, check_weakdeps = true), piracy = false)
# `piracies = false` - we extend/add some of the methods to the objects defined in the Distributions.jl
Aqua.test_all(ExponentialFamily, ambiguities = false, deps_compat = (; check_extras = false, check_weakdeps = true), piracies = false)

nthreads = max(cputhreads(), 1)
ncores = max(cpucores(), 1)
Expand Down
Loading