Skip to content

Commit

Permalink
Speedup logpdf
Browse files Browse the repository at this point in the history
  • Loading branch information
wouterwln committed Jan 9, 2025
1 parent bd657cb commit 2aba23d
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 22 deletions.
10 changes: 8 additions & 2 deletions src/distributions/tensor_dirichlet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import SpecialFunctions: digamma, loggamma
import Base: eltype
import Distributions: pdf, logpdf
using Distributions
using SpecialFunctions, LogExpFunctions

import FillArrays: Ones, Eye
import LoopVectorization: vmap, vmapreduce
Expand All @@ -27,9 +28,11 @@ The `a` field stores the matrix parameter of the distribution.
struct TensorDirichlet{T <: Real, N, A <: AbstractArray{T, N}, Ts} <: ContinuousTensorDistribution
    # Parameter array, stored exactly as passed to the constructor.
    a::A
    # Sum of `a` along the first dimension (`sum(a; dims = 1)`).
    α0::Ts
    # Cached `sum(loggamma, a; dims = 1) - loggamma.(α0)`, computed once at
    # construction so that it does not have to be rebuilt on every evaluation.
    lmnB::Ts

    # Build the distribution from `alpha`, precomputing `α0` and `lmnB`.
    function TensorDirichlet(alpha::AbstractArray{T, N}) where {T, N}
        alpha0 = sum(alpha; dims = 1)
        # NOTE(review): `lmnB` is stored under the same type parameter as `alpha0`;
        # for non-floating-point `alpha` the `loggamma` result would have to be
        # converted to `typeof(alpha0)` — confirm that case is not expected.
        lmnB = sum(loggamma, alpha; dims = 1) - loggamma.(alpha0)
        return new{T, N, typeof(alpha), typeof(alpha0)}(alpha, alpha0, lmnB)
    end
end

Expand Down Expand Up @@ -112,7 +115,10 @@ function BayesBase.rand!(rng::AbstractRNG, dist::TensorDirichlet{A}, container::
end

# Log-density of `dist` evaluated at `x`.
#
# For every slice along the first dimension this accumulates
# `sum((α - 1) * log(x)) - lmnB`, where `lmnB` is the normaliser cached on `dist`,
# and then sums the per-slice values into a single scalar.
# No check that `x` lies in the support is performed here.
function BayesBase.logpdf(dist::TensorDirichlet{R, N, A}, x::AbstractArray{T, N}) where {R, A, T <: Real, N}
    α = dist.a
    # `xlogy(a, b)` evaluates `a * log(b)`; reduce it over the first dimension only,
    # so the result lines up with the shape of `dist.lmnB`.
    s = sum(xlogy.(α .- 1, x); dims = 1)
    return sum(s .- dist.lmnB)
end

# Density of `dist` at `x`, obtained by exponentiating `logpdf(dist, x)`.
function BayesBase.pdf(dist::TensorDirichlet, x::Array{T, N}) where {T <: Real, N}
    log_density = logpdf(dist, x)
    return exp(log_density)
end
Expand Down
32 changes: 12 additions & 20 deletions test/distributions/tensor_dirichlet_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,8 @@ end
end
end
end


end


@testitem "TensorDirichlet: var" begin
include("distributions_setuptests.jl")

Expand All @@ -50,7 +47,6 @@ end
end
end
end

end

@testitem "TensorDirichlet: mean" begin
Expand All @@ -73,7 +69,6 @@ end
end
end
end

end

@testitem "TensorDirichlet: std" begin
Expand All @@ -96,7 +91,6 @@ end
end
end
end

end

@testitem "TensorDirichlet: cov" begin
Expand All @@ -112,10 +106,10 @@ end

temp = cov.(mat_of_dir)
old_shape = size(alpha)
new_shape = (first(old_shape),first(old_shape),Base.tail(old_shape)...)
new_shape = (first(old_shape), first(old_shape), Base.tail(old_shape)...)
mat_cov = ones(new_shape)
for i in CartesianIndices(Base.tail(size(alpha)))
mat_cov[:,:, i] = temp[i]
mat_cov[:, :, i] = temp[i]
end
@test cov(distribution) mat_cov
end
Expand Down Expand Up @@ -258,20 +252,20 @@ end
alpha2 = rand([d for _ in 1:rank]...) .+ 1
distribution1 = TensorDirichlet(alpha1)
distribution2 = TensorDirichlet(alpha2)

mat_of_dir_1 = Dirichlet.(eachslice(alpha1, dims = Tuple(2:rank)))
mat_of_dir_2 = Dirichlet.(eachslice(alpha2, dims = Tuple(2:rank)))
dim = rank-1
dim = rank - 1

prod_temp = Array{Dirichlet,dim}(undef, Base.tail(size(alpha1)))
prod_temp = Array{Dirichlet, dim}(undef, Base.tail(size(alpha1)))
for i in CartesianIndices(Base.tail(size(alpha1)))
prod_temp[i] = prod(PreserveTypeProd(Distribution),mat_of_dir_1[i],mat_of_dir_2[i])
prod_temp[i] = prod(PreserveTypeProd(Distribution), mat_of_dir_1[i], mat_of_dir_2[i])
end
mat_prod = similar(alpha1)
for i in CartesianIndices(Base.tail(size(alpha1)))
mat_prod[:,i] = prod_temp[i].alpha
mat_prod[:, i] = prod_temp[i].alpha
end
@test @inferred(prod(PreserveTypeProd(Distribution),distribution1,distribution2)) TensorDirichlet(mat_prod)
@test @inferred(prod(PreserveTypeProd(Distribution), distribution1, distribution2)) TensorDirichlet(mat_prod)
end
end
end
Expand Down Expand Up @@ -300,7 +294,6 @@ end
end
end
end

end

@testitem "TensorDirichlet: vague" begin
Expand Down Expand Up @@ -330,11 +323,10 @@ end
for rank in (3, 5)
for d in (2, 5, 10)
for _ in 1:10

alpha = rand([d for _ in 1:rank]...)
distribution = TensorDirichlet(alpha)
(naturalParam,) = unpack_parameters(TensorDirichlet, alpha)

mat_logPartition = sum(logPartitionDirichlet.(eachslice(alpha, dims = Tuple(2:rank))))
mat_grad = grad.(eachslice(alpha, dims = Tuple(2:rank)))
mat_info = sum(info.(eachslice(alpha, dims = Tuple(2:rank))))
Expand All @@ -350,9 +342,9 @@ end
@testitem "TensorDirichlet: logpdf" begin
include("distributions_setuptests.jl")

for rank in (3, 5)
for d in (2, 5, 10)
for _ in 1:10
for rank in (3, 4, 5, 6)
for d in (2, 4, 5, 10)
for i in 1:10
alpha = rand([d for _ in 1:rank]...)

distribution = TensorDirichlet(alpha)
Expand Down

0 comments on commit 2aba23d

Please sign in to comment.