Skip to content

Commit

Permalink
Merge pull request #222 from ReactiveBayes/dirichlet-tensor-performance
Browse files Browse the repository at this point in the history
Dirichlet tensor performance
  • Loading branch information
wouterwln authored Jan 9, 2025
2 parents 159e035 + bcebba4 commit 6160272
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 60 deletions.
4 changes: 2 additions & 2 deletions benchmark/benchmarks/tensordirichlet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ SUITE["tensordirichlet"] = BenchmarkGroup(
# `prod` BenchmarkGroup ========================
for rank in (3, 4, 5, 6)
for d in (5, 10, 20)
left = TensorDirichlet(rand([d for _ in 1:rank]...))
right = TensorDirichlet(rand([d for _ in 1:rank]...))
left = TensorDirichlet(rand([d for _ in 1:rank]...) .+ 1)
right = TensorDirichlet(rand([d for _ in 1:rank]...) .+ 1)
SUITE["tensordirichlet"]["prod"]["Closed(rank=$rank, d=$d)"] = @benchmarkable prod(ClosedProd(), $left, $right)
end
end
Expand Down
4 changes: 2 additions & 2 deletions scripts/benchmark.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,6 @@ if isempty(ARGS)
export_markdown("./benchmark_logs/last.md", result)
else
name = first(ARGS)
BenchmarkTools.judge(ExponentialFamily, name; judgekwargs = Dict(:time_tolerance => 0.1, :memory_tolerance => 0.05))
export_markdown("benchmark_vs_$(name)_result.md", result)
result = BenchmarkTools.judge(ExponentialFamily, name; judgekwargs = Dict(:time_tolerance => 0.1, :memory_tolerance => 0.05))
export_markdown("./benchmark_logs/benchmark_vs_$(name)_result.md", result)
end
59 changes: 23 additions & 36 deletions src/distributions/tensor_dirichlet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import SpecialFunctions: digamma, loggamma
import Base: eltype
import Distributions: pdf, logpdf
using Distributions
using SpecialFunctions, LogExpFunctions

import FillArrays: Ones, Eye
import LoopVectorization: vmap, vmapreduce
Expand All @@ -24,8 +25,18 @@ The `a` field stores the matrix parameter of the distribution.
- a[:,m,n] are the parameters of a Dirichlet distribution
- a[:,m_1,n_1] and a[:,m_2,n_2] are supposed independent if (m_1,n_1) not equal to (m_2,n_2).
"""
struct TensorDirichlet{T <: Real, N, A <: AbstractArray{T, N}} <: ContinuousTensorDistribution
struct TensorDirichlet{T <: Real, N, A <: AbstractArray{T, N}, Ts} <: ContinuousTensorDistribution
a::A
α0::Ts
lmnB::Ts
function TensorDirichlet(alpha::AbstractArray{T, N}) where {T, N}
if !all(x -> x > zero(x), alpha)
throw(ArgumentError("All elements of the alpha tensor should be positive"))
end
alpha0 = sum(alpha; dims = 1)
lmnB = sum(loggamma, alpha; dims = 1) - loggamma.(alpha0)
new{T, N, typeof(alpha), typeof(alpha0)}(alpha, alpha0, lmnB)
end
end

get_dirichlet_parameters(dist::TensorDirichlet{T, N, A}) where {T, N, A} = eachslice(dist.a, dims = Tuple(2:N))
Expand All @@ -35,7 +46,7 @@ BayesBase.params(::MeanParametersSpace, dist::TensorDirichlet) = (reduce(vcat, e
getbasemeasure(::Type{TensorDirichlet}) = (x) -> sum([x[:, i] for i in CartesianIndices(Base.tail(size(x)))])
getsufficientstatistics(::TensorDirichlet) = (x -> vmap(log, x),)

BayesBase.mean(dist::TensorDirichlet) = dist.a ./ sum(dist.a; dims = 1)
BayesBase.mean(dist::TensorDirichlet) = dist.a ./ dist.α0
function BayesBase.cov(dist::TensorDirichlet{T}) where {T}
s = size(dist.a)
news = (first(s), first(s), Base.tail(s)...)
Expand All @@ -45,12 +56,11 @@ function BayesBase.cov(dist::TensorDirichlet{T}) where {T}
end
return v
end

function BayesBase.var(dist::TensorDirichlet)
v = similar(dist.a)
for i in CartesianIndices(Base.tail(size(dist.a)))
v[:, i] .= var(Dirichlet(dist.a[:, i]))
end
function BayesBase.var(dist::TensorDirichlet{T, N, A, Ts}) where {T, N, A, Ts}
α = dist.a
α0 = dist.α0
c = inv.(α0 .^ 2 .* (α0 .+ 1))
v = α .* (α0 .- α) .* c
return v
end
BayesBase.std(dist::TensorDirichlet) = map(d -> std(Dirichlet(d)), extract_collection(dist))
Expand Down Expand Up @@ -108,41 +118,18 @@ function BayesBase.rand!(rng::AbstractRNG, dist::TensorDirichlet{A}, container::
end

function BayesBase.logpdf(dist::TensorDirichlet{R, N, A}, x::AbstractArray{T, N}) where {R, A, T <: Real, N}
out = zero(eltype(x))
for i in CartesianIndices(extract_collection(dist))
out += logpdf(Dirichlet(dist.a[:, i]), @view x[:, i])
end
return out
end

function _dirichlet_logpdf::AbstractVector{T}, x::AbstractVector{T}) where {T}
α0 = sum(α)
lmB = loggamma(α0) - sum(loggamma.(α))
if length(α) != length(x) || sum(x) != 1 || any(x -> x < 0, x)
return xlogy(one(eltype(α)), zero(eltype(x))) - lmB
end
s = sum(xlogy(αi - 1, xi) for (αi, xi) in zip(α, x))
return s - lmB
α = dist.a
α0 = dist.α0
s = sum(xlogy.(α .- 1, x); dims = 1)
return sum(s .- dist.lmnB)
end

BayesBase.pdf(dist::TensorDirichlet, x::Array{T, N}) where {T <: Real, N} = exp(logpdf(dist, x))

BayesBase.default_prod_rule(::Type{<:TensorDirichlet}, ::Type{<:TensorDirichlet}) = PreserveTypeProd(Distribution)

function BayesBase.prod(::PreserveTypeProd{Distribution}, left::TensorDirichlet, right::TensorDirichlet)
paramL = extract_collection(left)
paramR = extract_collection(right)
Ones = ones(size(left.a))
Ones = extract_collection(TensorDirichlet(Ones))
param = copy(Ones)
for i in eachindex(paramL)
param[i] .= paramL[i] .+ paramR[i] .- Ones[i]
end
out = similar(left.a)
for i in CartesianIndices(param)
out[:, i] = param[i]
end
return TensorDirichlet(out)
return TensorDirichlet(left.a .+ right.a .- 1)
end

function BayesBase.insupport(ef::ExponentialFamilyDistribution{TensorDirichlet}, x)
Expand Down
32 changes: 12 additions & 20 deletions test/distributions/tensor_dirichlet_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,8 @@ end
end
end
end


end


@testitem "TensorDirichlet: var" begin
include("distributions_setuptests.jl")

Expand All @@ -50,7 +47,6 @@ end
end
end
end

end

@testitem "TensorDirichlet: mean" begin
Expand All @@ -73,7 +69,6 @@ end
end
end
end

end

@testitem "TensorDirichlet: std" begin
Expand All @@ -96,7 +91,6 @@ end
end
end
end

end

@testitem "TensorDirichlet: cov" begin
Expand All @@ -112,10 +106,10 @@ end

temp = cov.(mat_of_dir)
old_shape = size(alpha)
new_shape = (first(old_shape),first(old_shape),Base.tail(old_shape)...)
new_shape = (first(old_shape), first(old_shape), Base.tail(old_shape)...)
mat_cov = ones(new_shape)
for i in CartesianIndices(Base.tail(size(alpha)))
mat_cov[:,:, i] = temp[i]
mat_cov[:, :, i] = temp[i]
end
@test cov(distribution) mat_cov
end
Expand Down Expand Up @@ -258,20 +252,20 @@ end
alpha2 = rand([d for _ in 1:rank]...) .+ 1
distribution1 = TensorDirichlet(alpha1)
distribution2 = TensorDirichlet(alpha2)

mat_of_dir_1 = Dirichlet.(eachslice(alpha1, dims = Tuple(2:rank)))
mat_of_dir_2 = Dirichlet.(eachslice(alpha2, dims = Tuple(2:rank)))
dim = rank-1
dim = rank - 1

prod_temp = Array{Dirichlet,dim}(undef, Base.tail(size(alpha1)))
prod_temp = Array{Dirichlet, dim}(undef, Base.tail(size(alpha1)))
for i in CartesianIndices(Base.tail(size(alpha1)))
prod_temp[i] = prod(PreserveTypeProd(Distribution),mat_of_dir_1[i],mat_of_dir_2[i])
prod_temp[i] = prod(PreserveTypeProd(Distribution), mat_of_dir_1[i], mat_of_dir_2[i])
end
mat_prod = similar(alpha1)
for i in CartesianIndices(Base.tail(size(alpha1)))
mat_prod[:,i] = prod_temp[i].alpha
mat_prod[:, i] = prod_temp[i].alpha
end
@test @inferred(prod(PreserveTypeProd(Distribution),distribution1,distribution2)) TensorDirichlet(mat_prod)
@test @inferred(prod(PreserveTypeProd(Distribution), distribution1, distribution2)) TensorDirichlet(mat_prod)
end
end
end
Expand Down Expand Up @@ -300,7 +294,6 @@ end
end
end
end

end

@testitem "TensorDirichlet: vague" begin
Expand Down Expand Up @@ -330,11 +323,10 @@ end
for rank in (3, 5)
for d in (2, 5, 10)
for _ in 1:10

alpha = rand([d for _ in 1:rank]...)
distribution = TensorDirichlet(alpha)
(naturalParam,) = unpack_parameters(TensorDirichlet, alpha)

mat_logPartition = sum(logPartitionDirichlet.(eachslice(alpha, dims = Tuple(2:rank))))
mat_grad = grad.(eachslice(alpha, dims = Tuple(2:rank)))
mat_info = sum(info.(eachslice(alpha, dims = Tuple(2:rank))))
Expand All @@ -350,9 +342,9 @@ end
@testitem "TensorDirichlet: logpdf" begin
include("distributions_setuptests.jl")

for rank in (3, 5)
for d in (2, 5, 10)
for _ in 1:10
for rank in (3, 4, 5, 6)
for d in (2, 4, 5, 10)
for i in 1:10
alpha = rand([d for _ in 1:rank]...)

distribution = TensorDirichlet(alpha)
Expand Down

0 comments on commit 6160272

Please sign in to comment.