# Patch summary (unified diff over four files; the hunks below are unchanged).
#
# benchmark/benchmarks/tensordirichlet.jl
#   - The `prod` benchmark builds `left`/`right` from `rand(...) .+ 1` instead of `rand(...)`.
#
# scripts/benchmark.jl
#   - In the `else` branch the value returned by `BenchmarkTools.judge(...)` is now bound to
#     `result` before `export_markdown` is called with it.
#   - The comparison report is written to "./benchmark_logs/benchmark_vs_$(name)_result.md".
#
# src/distributions/tensor_dirichlet.jl
#   - Adds `using SpecialFunctions, LogExpFunctions`.
#   - `TensorDirichlet` gains a fourth type parameter `Ts` and two fields, `α0` and `lmnB`.
#     The new inner constructor throws `ArgumentError` unless every element of `alpha` is
#     `> zero(x)`, then stores `alpha0 = sum(alpha; dims = 1)` and
#     `lmnB = sum(loggamma, alpha; dims = 1) - loggamma.(alpha0)`.
#   - `mean` divides by the stored `dist.α0` instead of recomputing `sum(dist.a; dims = 1)`.
#   - `var` replaces the per-slice `var(Dirichlet(...))` loop with the broadcast expression
#     `α .* (α0 .- α) .* inv.(α0 .^ 2 .* (α0 .+ 1))`.
#   - `logpdf` replaces the per-slice `logpdf(Dirichlet(...))` loop with
#     `sum(sum(xlogy.(α .- 1, x); dims = 1) .- dist.lmnB)`; the helper `_dirichlet_logpdf`
#     is deleted.
#   - `prod(::PreserveTypeProd{Distribution}, left, right)` becomes
#     `TensorDirichlet(left.a .+ right.a .- 1)`.
#
# test/distributions/tensor_dirichlet_test.jl
#   - Whitespace and formatting clean-ups (blank lines removed, spaces after commas and
#     around `-`).
#   - The logpdf test iterates `rank in (3, 4, 5, 6)` and `d in (2, 4, 5, 10)`; its innermost
#     loop variable changes from `_` to `i`.
#
# NOTE(review): the new `logpdf` binds `α0 = dist.α0`, but `α0` is not used in the visible body.
# NOTE(review): the new `logpdf` contains no visible support check, while the deleted
#   `_dirichlet_logpdf` tested length, `sum(x) != 1` and negative entries. Presumably the old
#   per-slice `logpdf(Dirichlet(...))` call handled out-of-support `x` itself; confirm that
#   dropping that handling is intended.
# NOTE(review): `prod` now goes through the validating constructor, so it throws
#   `ArgumentError` whenever `left.a .+ right.a .- 1` has a non-positive entry. This looks
#   like the reason the benchmark inputs gained `.+ 1` — confirm callers expect the throw.
# NOTE(review): `α0` and `lmnB` are computed once in the constructor. If `a` is ever mutated
#   in place after construction they will not be refreshed — verify no caller does that.
# NOTE(review): the new `var` signature names `T, N, A, Ts`, none of which is used in the
#   visible body.
# NOTE(review): no use of the renamed loop variable `i` is visible in the logpdf test hunk.
#
diff --git a/benchmark/benchmarks/tensordirichlet.jl b/benchmark/benchmarks/tensordirichlet.jl index 855fcb7a..eaaa58ae 100644 --- a/benchmark/benchmarks/tensordirichlet.jl +++ b/benchmark/benchmarks/tensordirichlet.jl @@ -8,8 +8,8 @@ SUITE["tensordirichlet"] = BenchmarkGroup( # `prod` BenchmarkGroup ======================== for rank in (3, 4, 5, 6) for d in (5, 10, 20) - left = TensorDirichlet(rand([d for _ in 1:rank]...)) - right = TensorDirichlet(rand([d for _ in 1:rank]...)) + left = TensorDirichlet(rand([d for _ in 1:rank]...) .+ 1) + right = TensorDirichlet(rand([d for _ in 1:rank]...) .+ 1) SUITE["tensordirichlet"]["prod"]["Closed(rank=$rank, d=$d)"] = @benchmarkable prod(ClosedProd(), $left, $right) end end diff --git a/scripts/benchmark.jl b/scripts/benchmark.jl index 4ef9c9f9..1009bcc1 100644 --- a/scripts/benchmark.jl +++ b/scripts/benchmark.jl @@ -10,6 +10,6 @@ if isempty(ARGS) export_markdown("./benchmark_logs/last.md", result) else name = first(ARGS) - BenchmarkTools.judge(ExponentialFamily, name; judgekwargs = Dict(:time_tolerance => 0.1, :memory_tolerance => 0.05)) - export_markdown("benchmark_vs_$(name)_result.md", result) + result = BenchmarkTools.judge(ExponentialFamily, name; judgekwargs = Dict(:time_tolerance => 0.1, :memory_tolerance => 0.05)) + export_markdown("./benchmark_logs/benchmark_vs_$(name)_result.md", result) end diff --git a/src/distributions/tensor_dirichlet.jl b/src/distributions/tensor_dirichlet.jl index b312db45..3440d890 100644 --- a/src/distributions/tensor_dirichlet.jl +++ b/src/distributions/tensor_dirichlet.jl @@ -4,6 +4,7 @@ import SpecialFunctions: digamma, loggamma import Base: eltype import Distributions: pdf, logpdf using Distributions +using SpecialFunctions, LogExpFunctions import FillArrays: Ones, Eye import LoopVectorization: vmap, vmapreduce @@ -24,8 +25,18 @@ The `a` field stores the matrix parameter of the distribution. 
- a[:,m,n] are the parameters of a Dirichlet distribution - a[:,m_1,n_1] and a[:,m_2,n_2] are supposed independent if (m_1,n_1) not equal to (m_2,n_2). """ -struct TensorDirichlet{T <: Real, N, A <: AbstractArray{T, N}} <: ContinuousTensorDistribution +struct TensorDirichlet{T <: Real, N, A <: AbstractArray{T, N}, Ts} <: ContinuousTensorDistribution a::A + α0::Ts + lmnB::Ts + function TensorDirichlet(alpha::AbstractArray{T, N}) where {T, N} + if !all(x -> x > zero(x), alpha) + throw(ArgumentError("All elements of the alpha tensor should be positive")) + end + alpha0 = sum(alpha; dims = 1) + lmnB = sum(loggamma, alpha; dims = 1) - loggamma.(alpha0) + new{T, N, typeof(alpha), typeof(alpha0)}(alpha, alpha0, lmnB) + end end get_dirichlet_parameters(dist::TensorDirichlet{T, N, A}) where {T, N, A} = eachslice(dist.a, dims = Tuple(2:N)) @@ -35,7 +46,7 @@ BayesBase.params(::MeanParametersSpace, dist::TensorDirichlet) = (reduce(vcat, e getbasemeasure(::Type{TensorDirichlet}) = (x) -> sum([x[:, i] for i in CartesianIndices(Base.tail(size(x)))]) getsufficientstatistics(::TensorDirichlet) = (x -> vmap(log, x),) -BayesBase.mean(dist::TensorDirichlet) = dist.a ./ sum(dist.a; dims = 1) +BayesBase.mean(dist::TensorDirichlet) = dist.a ./ dist.α0 function BayesBase.cov(dist::TensorDirichlet{T}) where {T} s = size(dist.a) news = (first(s), first(s), Base.tail(s)...) 
@@ -45,12 +56,11 @@ function BayesBase.cov(dist::TensorDirichlet{T}) where {T} end return v end - -function BayesBase.var(dist::TensorDirichlet) - v = similar(dist.a) - for i in CartesianIndices(Base.tail(size(dist.a))) - v[:, i] .= var(Dirichlet(dist.a[:, i])) - end +function BayesBase.var(dist::TensorDirichlet{T, N, A, Ts}) where {T, N, A, Ts} + α = dist.a + α0 = dist.α0 + c = inv.(α0 .^ 2 .* (α0 .+ 1)) + v = α .* (α0 .- α) .* c return v end BayesBase.std(dist::TensorDirichlet) = map(d -> std(Dirichlet(d)), extract_collection(dist)) @@ -108,21 +118,10 @@ function BayesBase.rand!(rng::AbstractRNG, dist::TensorDirichlet{A}, container:: end function BayesBase.logpdf(dist::TensorDirichlet{R, N, A}, x::AbstractArray{T, N}) where {R, A, T <: Real, N} - out = zero(eltype(x)) - for i in CartesianIndices(extract_collection(dist)) - out += logpdf(Dirichlet(dist.a[:, i]), @view x[:, i]) - end - return out -end - -function _dirichlet_logpdf(α::AbstractVector{T}, x::AbstractVector{T}) where {T} - α0 = sum(α) - lmB = loggamma(α0) - sum(loggamma.(α)) - if length(α) != length(x) || sum(x) != 1 || any(x -> x < 0, x) - return xlogy(one(eltype(α)), zero(eltype(x))) - lmB - end - s = sum(xlogy(αi - 1, xi) for (αi, xi) in zip(α, x)) - return s - lmB + α = dist.a + α0 = dist.α0 + s = sum(xlogy.(α .- 1, x); dims = 1) + return sum(s .- dist.lmnB) end BayesBase.pdf(dist::TensorDirichlet, x::Array{T, N}) where {T <: Real, N} = exp(logpdf(dist, x)) @@ -130,19 +129,7 @@ BayesBase.pdf(dist::TensorDirichlet, x::Array{T, N}) where {T <: Real, N} = exp( BayesBase.default_prod_rule(::Type{<:TensorDirichlet}, ::Type{<:TensorDirichlet}) = PreserveTypeProd(Distribution) function BayesBase.prod(::PreserveTypeProd{Distribution}, left::TensorDirichlet, right::TensorDirichlet) - paramL = extract_collection(left) - paramR = extract_collection(right) - Ones = ones(size(left.a)) - Ones = extract_collection(TensorDirichlet(Ones)) - param = copy(Ones) - for i in eachindex(paramL) - param[i] .= paramL[i] .+ 
paramR[i] .- Ones[i] - end - out = similar(left.a) - for i in CartesianIndices(param) - out[:, i] = param[i] - end - return TensorDirichlet(out) + return TensorDirichlet(left.a .+ right.a .- 1) end function BayesBase.insupport(ef::ExponentialFamilyDistribution{TensorDirichlet}, x) diff --git a/test/distributions/tensor_dirichlet_test.jl b/test/distributions/tensor_dirichlet_test.jl index 2775fd16..7e99f6ab 100644 --- a/test/distributions/tensor_dirichlet_test.jl +++ b/test/distributions/tensor_dirichlet_test.jl @@ -25,11 +25,8 @@ end end end end - - end - @testitem "TensorDirichlet: var" begin include("distributions_setuptests.jl") @@ -50,7 +47,6 @@ end end end end - end @testitem "TensorDirichlet: mean" begin @@ -73,7 +69,6 @@ end end end end - end @testitem "TensorDirichlet: std" begin @@ -96,7 +91,6 @@ end end end end - end @testitem "TensorDirichlet: cov" begin @@ -112,10 +106,10 @@ end temp = cov.(mat_of_dir) old_shape = size(alpha) - new_shape = (first(old_shape),first(old_shape),Base.tail(old_shape)...) + new_shape = (first(old_shape), first(old_shape), Base.tail(old_shape)...) mat_cov = ones(new_shape) for i in CartesianIndices(Base.tail(size(alpha))) - mat_cov[:,:, i] = temp[i] + mat_cov[:, :, i] = temp[i] end @test cov(distribution) ≈ mat_cov end @@ -258,20 +252,20 @@ end alpha2 = rand([d for _ in 1:rank]...) 
.+ 1 distribution1 = TensorDirichlet(alpha1) distribution2 = TensorDirichlet(alpha2) - + mat_of_dir_1 = Dirichlet.(eachslice(alpha1, dims = Tuple(2:rank))) mat_of_dir_2 = Dirichlet.(eachslice(alpha2, dims = Tuple(2:rank))) - dim = rank-1 + dim = rank - 1 - prod_temp = Array{Dirichlet,dim}(undef, Base.tail(size(alpha1))) + prod_temp = Array{Dirichlet, dim}(undef, Base.tail(size(alpha1))) for i in CartesianIndices(Base.tail(size(alpha1))) - prod_temp[i] = prod(PreserveTypeProd(Distribution),mat_of_dir_1[i],mat_of_dir_2[i]) + prod_temp[i] = prod(PreserveTypeProd(Distribution), mat_of_dir_1[i], mat_of_dir_2[i]) end mat_prod = similar(alpha1) for i in CartesianIndices(Base.tail(size(alpha1))) - mat_prod[:,i] = prod_temp[i].alpha + mat_prod[:, i] = prod_temp[i].alpha end - @test @inferred(prod(PreserveTypeProd(Distribution),distribution1,distribution2)) ≈ TensorDirichlet(mat_prod) + @test @inferred(prod(PreserveTypeProd(Distribution), distribution1, distribution2)) ≈ TensorDirichlet(mat_prod) end end end @@ -300,7 +294,6 @@ end end end end - end @testitem "TensorDirichlet: vague" begin @@ -330,11 +323,10 @@ end for rank in (3, 5) for d in (2, 5, 10) for _ in 1:10 - alpha = rand([d for _ in 1:rank]...) distribution = TensorDirichlet(alpha) (naturalParam,) = unpack_parameters(TensorDirichlet, alpha) - + mat_logPartition = sum(logPartitionDirichlet.(eachslice(alpha, dims = Tuple(2:rank)))) mat_grad = grad.(eachslice(alpha, dims = Tuple(2:rank))) mat_info = sum(info.(eachslice(alpha, dims = Tuple(2:rank)))) @@ -350,9 +342,9 @@ end @testitem "TensorDirichlet: logpdf" begin include("distributions_setuptests.jl") - for rank in (3, 5) - for d in (2, 5, 10) - for _ in 1:10 + for rank in (3, 4, 5, 6) + for d in (2, 4, 5, 10) + for i in 1:10 alpha = rand([d for _ in 1:rank]...) distribution = TensorDirichlet(alpha)