From 2aba23dd91cdbdcf87a2305dc5cde3e5d057be55 Mon Sep 17 00:00:00 2001 From: Wouter Nuijten Date: Thu, 9 Jan 2025 13:25:30 +0100 Subject: [PATCH] Speedup logpdf --- src/distributions/tensor_dirichlet.jl | 10 +++++-- test/distributions/tensor_dirichlet_test.jl | 32 ++++++++------------- 2 files changed, 20 insertions(+), 22 deletions(-) diff --git a/src/distributions/tensor_dirichlet.jl b/src/distributions/tensor_dirichlet.jl index 5c039be2..f9efd689 100644 --- a/src/distributions/tensor_dirichlet.jl +++ b/src/distributions/tensor_dirichlet.jl @@ -4,6 +4,7 @@ import SpecialFunctions: digamma, loggamma import Base: eltype import Distributions: pdf, logpdf using Distributions +using SpecialFunctions, LogExpFunctions import FillArrays: Ones, Eye import LoopVectorization: vmap, vmapreduce @@ -27,9 +28,11 @@ The `a` field stores the matrix parameter of the distribution. struct TensorDirichlet{T <: Real, N, A <: AbstractArray{T, N}, Ts} <: ContinuousTensorDistribution a::A α0::Ts + lmnB::Ts function TensorDirichlet(alpha::AbstractArray{T, N}) where {T, N} alpha0 = sum(alpha; dims = 1) - new{T, N, typeof(alpha), typeof(alpha0)}(alpha, alpha0) + lmnB = sum(loggamma, alpha; dims = 1) - loggamma.(alpha0) + new{T, N, typeof(alpha), typeof(alpha0)}(alpha, alpha0, lmnB) end end @@ -112,7 +115,10 @@ function BayesBase.rand!(rng::AbstractRNG, dist::TensorDirichlet{A}, container:: end function BayesBase.logpdf(dist::TensorDirichlet{R, N, A}, x::AbstractArray{T, N}) where {R, A, T <: Real, N} - return sum(logpdf.(Dirichlet.(get_dirichlet_parameters(dist)), eachslice(x, dims = Tuple(2:N)))) + α = dist.a + α0 = dist.α0 + s = sum(xlogy.(α .- 1, x); dims = 1) + return sum(s .- dist.lmnB) end BayesBase.pdf(dist::TensorDirichlet, x::Array{T, N}) where {T <: Real, N} = exp(logpdf(dist, x)) diff --git a/test/distributions/tensor_dirichlet_test.jl b/test/distributions/tensor_dirichlet_test.jl index c362498c..ad842ceb 100644 --- a/test/distributions/tensor_dirichlet_test.jl +++ b/test/distributions/tensor_dirichlet_test.jl @@ -25,11 +25,8 @@ end end end end - - end - @testitem "TensorDirichlet: var" begin include("distributions_setuptests.jl") @@ -50,7 +47,6 @@ end end end end - end @testitem "TensorDirichlet: mean" begin @@ -73,7 +69,6 @@ end end end end - end @testitem "TensorDirichlet: std" begin @@ -96,7 +91,6 @@ end end end end - end @testitem "TensorDirichlet: cov" begin @@ -112,10 +106,10 @@ end temp = cov.(mat_of_dir) old_shape = size(alpha) - new_shape = (first(old_shape),first(old_shape),Base.tail(old_shape)...) + new_shape = (first(old_shape), first(old_shape), Base.tail(old_shape)...) mat_cov = ones(new_shape) for i in CartesianIndices(Base.tail(size(alpha))) - mat_cov[:,:, i] = temp[i] + mat_cov[:, :, i] = temp[i] end @test cov(distribution) ≈ mat_cov end @@ -258,20 +252,20 @@ end alpha2 = rand([d for _ in 1:rank]...) .+ 1 distribution1 = TensorDirichlet(alpha1) distribution2 = TensorDirichlet(alpha2) - + mat_of_dir_1 = Dirichlet.(eachslice(alpha1, dims = Tuple(2:rank))) mat_of_dir_2 = Dirichlet.(eachslice(alpha2, dims = Tuple(2:rank))) - dim = rank-1 + dim = rank - 1 - prod_temp = Array{Dirichlet,dim}(undef, Base.tail(size(alpha1))) + prod_temp = Array{Dirichlet, dim}(undef, Base.tail(size(alpha1))) for i in CartesianIndices(Base.tail(size(alpha1))) - prod_temp[i] = prod(PreserveTypeProd(Distribution),mat_of_dir_1[i],mat_of_dir_2[i]) + prod_temp[i] = prod(PreserveTypeProd(Distribution), mat_of_dir_1[i], mat_of_dir_2[i]) end mat_prod = similar(alpha1) for i in CartesianIndices(Base.tail(size(alpha1))) - mat_prod[:,i] = prod_temp[i].alpha + mat_prod[:, i] = prod_temp[i].alpha end - @test @inferred(prod(PreserveTypeProd(Distribution),distribution1,distribution2)) ≈ TensorDirichlet(mat_prod) + @test @inferred(prod(PreserveTypeProd(Distribution), distribution1, distribution2)) ≈ TensorDirichlet(mat_prod) end end end @@ -300,7 +294,6 @@ end end end end - end @testitem "TensorDirichlet: vague" begin @@ -330,11 +323,10 @@ end for rank in (3, 5) for d in (2, 5, 10) for _ in 1:10 - alpha = rand([d for _ in 1:rank]...) distribution = TensorDirichlet(alpha) (naturalParam,) = unpack_parameters(TensorDirichlet, alpha) - + mat_logPartition = sum(logPartitionDirichlet.(eachslice(alpha, dims = Tuple(2:rank)))) mat_grad = grad.(eachslice(alpha, dims = Tuple(2:rank))) mat_info = sum(info.(eachslice(alpha, dims = Tuple(2:rank)))) @@ -350,9 +342,9 @@ end @testitem "TensorDirichlet: logpdf" begin include("distributions_setuptests.jl") - for rank in (3, 5) - for d in (2, 5, 10) - for _ in 1:10 + for rank in (3, 4, 5, 6) + for d in (2, 4, 5, 10) + for i in 1:10 alpha = rand([d for _ in 1:rank]...) distribution = TensorDirichlet(alpha)