From 6c9d756fd1c6203843efc6e73bb08bf26c9d1a55 Mon Sep 17 00:00:00 2001 From: Wouter Nuijten Date: Wed, 8 Jan 2025 13:09:27 +0100 Subject: [PATCH 01/10] Reimplement prod --- src/distributions/tensor_dirichlet.jl | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/src/distributions/tensor_dirichlet.jl b/src/distributions/tensor_dirichlet.jl index b312db45..bc6e4314 100644 --- a/src/distributions/tensor_dirichlet.jl +++ b/src/distributions/tensor_dirichlet.jl @@ -130,19 +130,7 @@ BayesBase.pdf(dist::TensorDirichlet, x::Array{T, N}) where {T <: Real, N} = exp( BayesBase.default_prod_rule(::Type{<:TensorDirichlet}, ::Type{<:TensorDirichlet}) = PreserveTypeProd(Distribution) function BayesBase.prod(::PreserveTypeProd{Distribution}, left::TensorDirichlet, right::TensorDirichlet) - paramL = extract_collection(left) - paramR = extract_collection(right) - Ones = ones(size(left.a)) - Ones = extract_collection(TensorDirichlet(Ones)) - param = copy(Ones) - for i in eachindex(paramL) - param[i] .= paramL[i] .+ paramR[i] .- Ones[i] - end - out = similar(left.a) - for i in CartesianIndices(param) - out[:, i] = param[i] - end - return TensorDirichlet(out) + return TensorDirichlet(left.a .+ right.a .- 1) end function BayesBase.insupport(ef::ExponentialFamilyDistribution{TensorDirichlet}, x) From 40f08e278855ec6152a623b044ddd0147b5a9732 Mon Sep 17 00:00:00 2001 From: Wouter Nuijten Date: Wed, 8 Jan 2025 13:13:59 +0100 Subject: [PATCH 02/10] Update judge benchmarks --- scripts/benchmark.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/benchmark.jl b/scripts/benchmark.jl index 4ef9c9f9..780b1c5a 100644 --- a/scripts/benchmark.jl +++ b/scripts/benchmark.jl @@ -10,6 +10,6 @@ if isempty(ARGS) export_markdown("./benchmark_logs/last.md", result) else name = first(ARGS) - BenchmarkTools.judge(ExponentialFamily, name; judgekwargs = Dict(:time_tolerance => 0.1, :memory_tolerance => 0.05)) + result = BenchmarkTools.judge(ExponentialFamily, name; judgekwargs = Dict(:time_tolerance => 0.1, :memory_tolerance => 0.05)) export_markdown("benchmark_vs_$(name)_result.md", result) end From 78026b82763acfbe5d65e5311e5a60ba9cb720a8 Mon Sep 17 00:00:00 2001 From: Wouter Nuijten Date: Wed, 8 Jan 2025 13:14:18 +0100 Subject: [PATCH 03/10] Speedup logpdf --- src/distributions/tensor_dirichlet.jl | 16 +--------------- 1 file changed, 1 insertion(+), 15 deletions(-) diff --git a/src/distributions/tensor_dirichlet.jl b/src/distributions/tensor_dirichlet.jl index bc6e4314..bb0bb75e 100644 --- a/src/distributions/tensor_dirichlet.jl +++ b/src/distributions/tensor_dirichlet.jl @@ -108,21 +108,7 @@ function BayesBase.rand!(rng::AbstractRNG, dist::TensorDirichlet{A}, container:: end function BayesBase.logpdf(dist::TensorDirichlet{R, N, A}, x::AbstractArray{T, N}) where {R, A, T <: Real, N} - out = zero(eltype(x)) - for i in CartesianIndices(extract_collection(dist)) - out += logpdf(Dirichlet(dist.a[:, i]), @view x[:, i]) - end - return out -end - -function _dirichlet_logpdf(α::AbstractVector{T}, x::AbstractVector{T}) where {T} - α0 = sum(α) - lmB = loggamma(α0) - sum(loggamma.(α)) - if length(α) != length(x) || sum(x) != 1 || any(x -> x < 0, x) - return xlogy(one(eltype(α)), zero(eltype(x))) - lmB - end - s = sum(xlogy(αi - 1, xi) for (αi, xi) in zip(α, x)) - return s - lmB + return sum(logpdf.(Dirichlet.(get_dirichlet_parameters(dist)), eachslice(x, dims = Tuple(2:N)))) end BayesBase.pdf(dist::TensorDirichlet, x::Array{T, N}) where {T <: Real, N} = exp(logpdf(dist, x)) From e43b376ab32c63c7179102b75f73d86fd2c47850 Mon Sep 17 00:00:00 2001 From: Wouter Nuijten Date: Wed, 8 Jan 2025 13:21:33 +0100 Subject: [PATCH 04/10] Update benchmarks --- scripts/benchmark.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/benchmark.jl b/scripts/benchmark.jl index 780b1c5a..1009bcc1 100644 --- a/scripts/benchmark.jl +++ b/scripts/benchmark.jl @@ -11,5 +11,5 @@ if isempty(ARGS) else name = first(ARGS) result = BenchmarkTools.judge(ExponentialFamily, name; judgekwargs = Dict(:time_tolerance => 0.1, :memory_tolerance => 0.05)) - export_markdown("benchmark_vs_$(name)_result.md", result) + export_markdown("./benchmark_logs/benchmark_vs_$(name)_result.md", result) end From 678af98741667361dbb2b25dc24972bc2e84f353 Mon Sep 17 00:00:00 2001 From: Wouter Nuijten Date: Wed, 8 Jan 2025 15:58:00 +0100 Subject: [PATCH 05/10] Save alpha0 --- src/distributions/tensor_dirichlet.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/distributions/tensor_dirichlet.jl b/src/distributions/tensor_dirichlet.jl index bb0bb75e..cf0229a4 100644 --- a/src/distributions/tensor_dirichlet.jl +++ b/src/distributions/tensor_dirichlet.jl @@ -24,8 +24,13 @@ The `a` field stores the matrix parameter of the distribution. - a[:,m,n] are the parameters of a Dirichlet distribution - a[:,m_1,n_1] and a[:,m_2,n_2] are supposed independent if (m_1,n_1) not equal to (m_2,n_2). """ -struct TensorDirichlet{T <: Real, N, A <: AbstractArray{T, N}} <: ContinuousTensorDistribution +struct TensorDirichlet{T <: Real, N, A <: AbstractArray{T, N}, Ts} <: ContinuousTensorDistribution a::A + α0::Ts + function TensorDirichlet(alpha::AbstractArray{T, N}) where {T, N} + alpha0 = sum(alpha; dims = 1) + new{T, N, typeof(alpha), typeof(alpha0)}(alpha, alpha0) + end end get_dirichlet_parameters(dist::TensorDirichlet{T, N, A}) where {T, N, A} = eachslice(dist.a, dims = Tuple(2:N)) From 89dd0ba85dfee158da2fd2e1cd5e97004efee9ac Mon Sep 17 00:00:00 2001 From: Wouter Nuijten Date: Thu, 9 Jan 2025 11:57:05 +0100 Subject: [PATCH 06/10] reimplement var --- src/distributions/tensor_dirichlet.jl | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/distributions/tensor_dirichlet.jl b/src/distributions/tensor_dirichlet.jl index cf0229a4..0c1408bc 100644 --- a/src/distributions/tensor_dirichlet.jl +++ b/src/distributions/tensor_dirichlet.jl @@ -40,7 +40,7 @@ BayesBase.params(::MeanParametersSpace, dist::TensorDirichlet) = (reduce(vcat, e getbasemeasure(::Type{TensorDirichlet}) = (x) -> sum([x[:, i] for i in CartesianIndices(Base.tail(size(x)))]) getsufficientstatistics(::TensorDirichlet) = (x -> vmap(log, x),) -BayesBase.mean(dist::TensorDirichlet) = dist.a ./ sum(dist.a; dims = 1) +BayesBase.mean(dist::TensorDirichlet) = dist.a ./ dist.α0 function BayesBase.cov(dist::TensorDirichlet{T}) where {T} s = size(dist.a) news = (first(s), first(s), Base.tail(s)...) @@ -50,11 +50,13 @@ function BayesBase.cov(dist::TensorDirichlet{T}) where {T} end return v end - -function BayesBase.var(dist::TensorDirichlet) +function BayesBase.var(dist::TensorDirichlet{T, N, A, Ts}) where {T, N, A, Ts} v = similar(dist.a) - for i in CartesianIndices(Base.tail(size(dist.a))) - v[:, i] .= var(Dirichlet(dist.a[:, i])) + for (vel, α, α0) in zip(eachslice(v, dims = Tuple(2:N)), get_dirichlet_parameters(dist), dist.α0) + c = inv(α0^2 * (α0 + 1)) + for (i, _) in enumerate(vel) + vel[i] = α[i] * (α0 - α[i]) * c + end end return v end From bd657cbc478c8d1d9c2e02db61ab08f94441e2a4 Mon Sep 17 00:00:00 2001 From: Wouter Nuijten Date: Thu, 9 Jan 2025 12:02:46 +0100 Subject: [PATCH 07/10] Speedup var --- src/distributions/tensor_dirichlet.jl | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/distributions/tensor_dirichlet.jl b/src/distributions/tensor_dirichlet.jl index 0c1408bc..5c039be2 100644 --- a/src/distributions/tensor_dirichlet.jl +++ b/src/distributions/tensor_dirichlet.jl @@ -51,13 +51,10 @@ function BayesBase.cov(dist::TensorDirichlet{T}) where {T} return v end function BayesBase.var(dist::TensorDirichlet{T, N, A, Ts}) where {T, N, A, Ts} - v = similar(dist.a) - for (vel, α, α0) in zip(eachslice(v, dims = Tuple(2:N)), get_dirichlet_parameters(dist), dist.α0) - c = inv(α0^2 * (α0 + 1)) - for (i, _) in enumerate(vel) - vel[i] = α[i] * (α0 - α[i]) * c - end - end + α = dist.a + α0 = dist.α0 + c = inv.(α0 .^ 2 .* (α0 .+ 1)) + v = α .* (α0 .- α) .* c return v end BayesBase.std(dist::TensorDirichlet) = map(d -> std(Dirichlet(d)), extract_collection(dist)) From 2aba23dd91cdbdcf87a2305dc5cde3e5d057be55 Mon Sep 17 00:00:00 2001 From: Wouter Nuijten Date: Thu, 9 Jan 2025 13:25:30 +0100 Subject: [PATCH 08/10] Speedup logpdf --- src/distributions/tensor_dirichlet.jl | 10 +++++-- test/distributions/tensor_dirichlet_test.jl | 32 ++++++++------------- 2 files changed, 20 insertions(+), 22 deletions(-) diff --git a/src/distributions/tensor_dirichlet.jl b/src/distributions/tensor_dirichlet.jl index 5c039be2..f9efd689 100644 --- a/src/distributions/tensor_dirichlet.jl +++ b/src/distributions/tensor_dirichlet.jl @@ -4,6 +4,7 @@ import SpecialFunctions: digamma, loggamma import Base: eltype import Distributions: pdf, logpdf using Distributions +using SpecialFunctions, LogExpFunctions import FillArrays: Ones, Eye import LoopVectorization: vmap, vmapreduce @@ -27,9 +28,11 @@ The `a` field stores the matrix parameter of the distribution. struct TensorDirichlet{T <: Real, N, A <: AbstractArray{T, N}, Ts} <: ContinuousTensorDistribution a::A α0::Ts + lmnB::Ts function TensorDirichlet(alpha::AbstractArray{T, N}) where {T, N} alpha0 = sum(alpha; dims = 1) - new{T, N, typeof(alpha), typeof(alpha0)}(alpha, alpha0) + lmnB = sum(loggamma, alpha; dims = 1) - loggamma.(alpha0) + new{T, N, typeof(alpha), typeof(alpha0)}(alpha, alpha0, lmnB) end end @@ -112,7 +115,10 @@ function BayesBase.rand!(rng::AbstractRNG, dist::TensorDirichlet{A}, container:: end function BayesBase.logpdf(dist::TensorDirichlet{R, N, A}, x::AbstractArray{T, N}) where {R, A, T <: Real, N} - return sum(logpdf.(Dirichlet.(get_dirichlet_parameters(dist)), eachslice(x, dims = Tuple(2:N)))) + α = dist.a + α0 = dist.α0 + s = sum(xlogy.(α .- 1, x); dims = 1) + return sum(s .- dist.lmnB) end BayesBase.pdf(dist::TensorDirichlet, x::Array{T, N}) where {T <: Real, N} = exp(logpdf(dist, x)) diff --git a/test/distributions/tensor_dirichlet_test.jl b/test/distributions/tensor_dirichlet_test.jl index c362498c..ad842ceb 100644 --- a/test/distributions/tensor_dirichlet_test.jl +++ b/test/distributions/tensor_dirichlet_test.jl @@ -25,11 +25,8 @@ end end end end - - end - @testitem "TensorDirichlet: var" begin include("distributions_setuptests.jl") @@ -50,7 +47,6 @@ end end end end - end @testitem "TensorDirichlet: mean" begin @@ -73,7 +69,6 @@ end end end end - end @testitem "TensorDirichlet: std" begin @@ -96,7 +91,6 @@ end end end end - end @testitem "TensorDirichlet: cov" begin @@ -112,10 +106,10 @@ end temp = cov.(mat_of_dir) old_shape = size(alpha) - new_shape = (first(old_shape),first(old_shape),Base.tail(old_shape)...) + new_shape = (first(old_shape), first(old_shape), Base.tail(old_shape)...) mat_cov = ones(new_shape) for i in CartesianIndices(Base.tail(size(alpha))) - mat_cov[:,:, i] = temp[i] + mat_cov[:, :, i] = temp[i] end @test cov(distribution) ≈ mat_cov end @@ -258,20 +252,20 @@ end alpha2 = rand([d for _ in 1:rank]...) .+ 1 distribution1 = TensorDirichlet(alpha1) distribution2 = TensorDirichlet(alpha2) - + mat_of_dir_1 = Dirichlet.(eachslice(alpha1, dims = Tuple(2:rank))) mat_of_dir_2 = Dirichlet.(eachslice(alpha2, dims = Tuple(2:rank))) - dim = rank-1 + dim = rank - 1 - prod_temp = Array{Dirichlet,dim}(undef, Base.tail(size(alpha1))) + prod_temp = Array{Dirichlet, dim}(undef, Base.tail(size(alpha1))) for i in CartesianIndices(Base.tail(size(alpha1))) - prod_temp[i] = prod(PreserveTypeProd(Distribution),mat_of_dir_1[i],mat_of_dir_2[i]) + prod_temp[i] = prod(PreserveTypeProd(Distribution), mat_of_dir_1[i], mat_of_dir_2[i]) end mat_prod = similar(alpha1) for i in CartesianIndices(Base.tail(size(alpha1))) - mat_prod[:,i] = prod_temp[i].alpha + mat_prod[:, i] = prod_temp[i].alpha end - @test @inferred(prod(PreserveTypeProd(Distribution),distribution1,distribution2)) ≈ TensorDirichlet(mat_prod) + @test @inferred(prod(PreserveTypeProd(Distribution), distribution1, distribution2)) ≈ TensorDirichlet(mat_prod) end end end @@ -300,7 +294,6 @@ end end end end - end @testitem "TensorDirichlet: vague" begin @@ -330,11 +323,10 @@ end for rank in (3, 5) for d in (2, 5, 10) for _ in 1:10 - alpha = rand([d for _ in 1:rank]...) distribution = TensorDirichlet(alpha) (naturalParam,) = unpack_parameters(TensorDirichlet, alpha) - + mat_logPartition = sum(logPartitionDirichlet.(eachslice(alpha, dims = Tuple(2:rank)))) mat_grad = grad.(eachslice(alpha, dims = Tuple(2:rank))) mat_info = sum(info.(eachslice(alpha, dims = Tuple(2:rank)))) @@ -350,9 +342,9 @@ end @testitem "TensorDirichlet: logpdf" begin include("distributions_setuptests.jl") - for rank in (3, 5) - for d in (2, 5, 10) - for _ in 1:10 + for rank in (3, 4, 5, 6) + for d in (2, 4, 5, 10) + for i in 1:10 alpha = rand([d for _ in 1:rank]...) distribution = TensorDirichlet(alpha) From 2aa166dada0303269ab4fd707dcc85c80d9cd5c2 Mon Sep 17 00:00:00 2001 From: Wouter Nuijten Date: Thu, 9 Jan 2025 13:36:42 +0100 Subject: [PATCH 09/10] Introduce error msg for negative alpha --- src/distributions/tensor_dirichlet.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/distributions/tensor_dirichlet.jl b/src/distributions/tensor_dirichlet.jl index f9efd689..3440d890 100644 --- a/src/distributions/tensor_dirichlet.jl +++ b/src/distributions/tensor_dirichlet.jl @@ -30,6 +30,9 @@ struct TensorDirichlet{T <: Real, N, A <: AbstractArray{T, N}, Ts} <: Continuous α0::Ts lmnB::Ts function TensorDirichlet(alpha::AbstractArray{T, N}) where {T, N} + if !all(x -> x > zero(x), alpha) + throw(ArgumentError("All elements of the alpha tensor should be positive")) + end alpha0 = sum(alpha; dims = 1) lmnB = sum(loggamma, alpha; dims = 1) - loggamma.(alpha0) new{T, N, typeof(alpha), typeof(alpha0)}(alpha, alpha0, lmnB) From bcebba478c15b684364438e2d60ed617ead1d2cb Mon Sep 17 00:00:00 2001 From: Wouter Nuijten Date: Thu, 9 Jan 2025 13:36:50 +0100 Subject: [PATCH 10/10] Make benchmarks robust --- benchmark/benchmarks/tensordirichlet.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmark/benchmarks/tensordirichlet.jl b/benchmark/benchmarks/tensordirichlet.jl index 855fcb7a..eaaa58ae 100644 --- a/benchmark/benchmarks/tensordirichlet.jl +++ b/benchmark/benchmarks/tensordirichlet.jl @@ -8,8 +8,8 @@ SUITE["tensordirichlet"] = BenchmarkGroup( # `prod` BenchmarkGroup ======================== for rank in (3, 4, 5, 6) for d in (5, 10, 20) - left = TensorDirichlet(rand([d for _ in 1:rank]...)) - right = TensorDirichlet(rand([d for _ in 1:rank]...)) + left = TensorDirichlet(rand([d for _ in 1:rank]...) .+ 1) + right = TensorDirichlet(rand([d for _ in 1:rank]...) .+ 1) SUITE["tensordirichlet"]["prod"]["Closed(rank=$rank, d=$d)"] = @benchmarkable prod(ClosedProd(), $left, $right) end end