Dirichlet tensor #219

Draft · wants to merge 48 commits into base: main

Commits (48)
93047b2
testTensirDirichlet
Raphael-Tresor Nov 25, 2024
e5e08e3
tensorDirichlet not exponentialFamiliy compatible
Raphael-Tresor Dec 24, 2024
d3b5f1a
Update initialization of out
wouterwln Jan 7, 2025
f244d1c
Ensure all tests are in `testitem` blocks
wouterwln Jan 7, 2025
603d989
testTensirDirichlet
Raphael-Tresor Nov 25, 2024
bfc7466
tensorDirichlet not exponentialFamiliy compatible
Raphael-Tresor Dec 24, 2024
de2e32f
Update initialization of out
wouterwln Jan 7, 2025
6a9df05
Ensure all tests are in `testitem` blocks
wouterwln Jan 7, 2025
d8ccd98
Merge branch 'DirichletTensor' of https://github.com/ReactiveBayes/Ex…
wouterwln Jan 7, 2025
c399f37
Add TensorDirichlet benchmarks
wouterwln Jan 7, 2025
c033071
Add logpdf tests and benchmarks
wouterwln Jan 8, 2025
ff5cf02
Formatting
wouterwln Jan 8, 2025
0977bc3
Refine prod benchmarks
wouterwln Jan 8, 2025
d1bfa07
Update benchmarks
wouterwln Jan 8, 2025
baf0514
Update benchmarks
wouterwln Jan 8, 2025
3e83feb
Reimplement var
wouterwln Jan 8, 2025
b28e67e
Update logpdf tests
wouterwln Jan 8, 2025
1339467
Revise tests
wouterwln Jan 8, 2025
e38405f
Reimplement cov
wouterwln Jan 9, 2025
fd790dd
test_cov
Raphael-Tresor Jan 9, 2025
6a7b028
tests base functions
Raphael-Tresor Jan 9, 2025
6c9d756
Reimplement prod
wouterwln Jan 8, 2025
40f08e2
Update judge benchmarks
wouterwln Jan 8, 2025
78026b8
Speedup logpdf
wouterwln Jan 8, 2025
e43b376
Update benchmarks
wouterwln Jan 8, 2025
678af98
Save alpha0
wouterwln Jan 8, 2025
89dd0ba
reimplement var
wouterwln Jan 9, 2025
bd657cb
Speedup var
wouterwln Jan 9, 2025
2aba23d
Speedup logpdf
wouterwln Jan 9, 2025
2aa166d
Introduce error msg for negative alpha
wouterwln Jan 9, 2025
bcebba4
Make benchmarks robust
wouterwln Jan 9, 2025
159e035
test: unmark broken tests
Nimrais Jan 9, 2025
6160272
Merge pull request #222 from ReactiveBayes/dirichlet-tensor-performance
wouterwln Jan 9, 2025
e506a5d
Fix `std`
wouterwln Jan 9, 2025
2c8bd8c
Add entropy to benchmarks
wouterwln Jan 9, 2025
bbba975
Add out of support case for `logpdf`
wouterwln Jan 9, 2025
85d437a
Reimplement entropy
wouterwln Jan 9, 2025
f38bfe0
Merge pull request #224 from ReactiveBayes/dt_entropy
wouterwln Jan 9, 2025
f94b302
Merge branch 'main' into DirichletTensor
wouterwln Jan 9, 2025
4dcb512
Merge branch 'DirichletTensor' of https://github.com/ReactiveBayes/Ex…
wouterwln Jan 9, 2025
83da858
test rand
Raphael-Tresor Jan 9, 2025
b042beb
docs: they are always independent
Nimrais Jan 9, 2025
0e327fa
fix: mark broken methods
Nimrais Jan 10, 2025
54ef2d6
fix: proper conversion between Mean and Natural spaces
Nimrais Jan 10, 2025
19a9446
fix: implement rand! without extract_collection
Nimrais Jan 10, 2025
3e87207
fix: implement join_conditioner
Nimrais Jan 10, 2025
5791da8
fix: implement NaturalParametersSpace getlogpartition
Nimrais Jan 10, 2025
5253fbe
fix: implement getgradlogpartition for TensorDirichlet
Nimrais Jan 10, 2025
1 change: 1 addition & 0 deletions src/ExponentialFamily.jl
@@ -71,5 +71,6 @@ include("distributions/poisson.jl")
include("distributions/chi_squared.jl")
include("distributions/mv_normal_wishart.jl")
include("distributions/normal_gamma.jl")
include("distributions/tensor_dirichlet.jl")

end
195 changes: 195 additions & 0 deletions src/distributions/tensor_dirichlet.jl
@@ -0,0 +1,195 @@
export TensorDirichlet, ContinuousTensorDistribution

import SpecialFunctions: digamma, loggamma
import Base: eltype
import Distributions: pdf, logpdf
using Distributions

import SparseArrays: blockdiag, sparse
import FillArrays: Ones, Eye
import LoopVectorization: vmap, vmapreduce
using LinearAlgebra, Random

const ContinuousTensorDistribution = Distribution{ArrayLikeVariate, Continuous}

"""
TensorDirichlet{T <: Real, A <: AbstractArrray{T,3}} <: ContinuousTensorDistribution
> Member: Do not document the `{ }` signature, since it's an internal detail and it's also wrongly documented.

A tensor-valued Dirichlet distribution, where `T` is the element type of the parameter tensor `A`.
The `a` field stores the tensor parameter of the distribution.

# Fields
- `a::A`: The tensor parameter of the TensorDirichlet distribution.

# Model
- `a[:,m,n]` are the parameters of a Dirichlet distribution
- `a[:,m_1,n_1]` and `a[:,m_2,n_2]` are assumed independent whenever `(m_1,n_1) ≠ (m_2,n_2)`.
"""
struct TensorDirichlet{T<:Real, N , A <: AbstractArray{T,N}} <: ContinuousTensorDistribution
a::A
end
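As an editorial aside (not part of the diff): the model above is a collection of independent Dirichlet distributions, one per trailing index of the parameter tensor. A minimal NumPy/SciPy sketch, with a made-up `alpha`, shows why `BayesBase.mean` can normalize along the first axis:

```python
import numpy as np
from scipy.stats import dirichlet

# alpha[:, m, n] parameterizes one independent Dirichlet per trailing index (m, n)
alpha = np.array([[[1.0, 2.0],
                   [3.0, 1.0]],
                  [[2.0, 1.0],
                   [1.0, 4.0]]])  # shape (K, M, N) = (2, 2, 2)

# mirrors BayesBase.mean(dist): a ./ sum(a; dims = 1)
mean = alpha / alpha.sum(axis=0)

# slice-wise check against SciPy's Dirichlet mean
for m in range(alpha.shape[1]):
    for n in range(alpha.shape[2]):
        assert np.allclose(mean[:, m, n], dirichlet.mean(alpha[:, m, n]))
```

Each slice of `mean` is a probability vector, which is exactly the per-slice Dirichlet mean.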

extract_collection(dist::TensorDirichlet) = [dist.a[:,i] for i in CartesianIndices(Base.tail(size(dist.a)))]
> Member: `extract_collection(dist::TensorDirichlet) = (view(dist.a, :, i) for i in CartesianIndices(Base.tail(size(dist.a))))`?
>
> Contributor Author: Yes, with `@view` it will be much better. I will take your suggestion.

unpack_parameters(::Type{TensorDirichlet}, packed) = ([packed[:,i] for i in CartesianIndices(Base.tail(size(packed)))],)
BayesBase.params(::MeanParametersSpace, dist::TensorDirichlet) = (reduce(vcat,extract_collection(dist)),)
getbasemeasure(::Type{TensorDirichlet}) = (x) -> sum([x[:,i] for i in CartesianIndices(Base.tail(size(x)))])
getsufficientstatistics(::TensorDirichlet) = (x -> vmap(log, x),)


BayesBase.mean(dist::TensorDirichlet) = dist.a ./ sum(dist.a; dims = 1)
BayesBase.cov(dist::TensorDirichlet) = map(d->cov(Dirichlet(d)),extract_collection(dist))
BayesBase.var(dist::TensorDirichlet) = map(d->var(Dirichlet(d)),extract_collection(dist))
BayesBase.std(dist::TensorDirichlet) = map(d->std(Dirichlet(d)),extract_collection(dist))


BayesBase.params(dist::TensorDirichlet) = (dist.a,)

Base.size(dist::TensorDirichlet) = size(dist.a)
Base.eltype(::TensorDirichlet{T}) where {T} = T

function BayesBase.vague(::Type{<:TensorDirichlet}, dims::Int)
> Nimrais (Jan 7, 2025): In one of the vague methods (the one taking a tuple) you specify the element type of `ones`, while in the other you don't. Is there any reason you aren't writing it like this?
>
>     function BayesBase.vague(::Type{<:TensorDirichlet}, dims::Int)
>         return TensorDirichlet(ones(Float64, dims, dims))
>     end
>
> Raphael-Tresor (Jan 7, 2025): No specific reason, I missed this difference. Float64 is the default element type, so it should be the same. I will add the Float64 argument for clarity. Thanks.

return TensorDirichlet(ones(dims,dims))
end


function BayesBase.vague(::Type{<:TensorDirichlet}, dims::Tuple)
return TensorDirichlet(ones(Float64,dims))
end

function BayesBase.entropy(dist::TensorDirichlet)
return vmapreduce(+, extract_collection(dist)) do column
scolumn = sum(column)
-sum((column .- one(Float64)) .* (digamma.(column) .- digamma.(scolumn))) - loggamma(scolumn) +
sum(loggamma.(column))
end
end
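Since the slices are independent, the entropy of the joint distribution is the sum of the per-slice Dirichlet entropies, which is what the `vmapreduce` above computes. A hedged SciPy sketch with illustrative parameters:

```python
import numpy as np
from scipy.stats import dirichlet

alpha = np.array([[2.0, 3.0],
                  [1.0, 2.0]])  # two independent Dirichlet(2) columns

# total entropy = sum of per-slice Dirichlet entropies
total_entropy = sum(dirichlet.entropy(alpha[:, j]) for j in range(alpha.shape[1]))
```

The closed form per slice is `log B(α) + (α₀ − K)ψ(α₀) − Σⱼ(αⱼ − 1)ψ(αⱼ)`, matching the Julia body above term by term.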

BayesBase.promote_variate_type(::Type{Multivariate}, ::Type{<:Dirichlet}) = Dirichlet
BayesBase.promote_variate_type(::Type{Multivariate}, ::Type{<:TensorDirichlet}) = TensorDirichlet
BayesBase.promote_variate_type(::Type{ArrayLikeVariate}, ::Type{<:Dirichlet}) = TensorDirichlet

function BayesBase.rand(rng::AbstractRNG, dist::TensorDirichlet{T}) where {T}
container = similar(dist.a)
return rand!(rng, dist, container)
end

function BayesBase.rand(rng::AbstractRNG, dist::TensorDirichlet{T}, nsamples::Int64) where {T}
container = Vector{typeof(dist.a)}(undef, nsamples)
@inbounds for i in eachindex(container)
container[i] = similar(dist.a)
rand!(rng, dist, container[i])
end
return container
end

function BayesBase.rand!(rng::AbstractRNG, dist::TensorDirichlet, container::AbstractArray{T,N}) where {T <: Real, N }
for index in CartesianIndices(extract_collection(dist))
rand!(rng, Dirichlet(dist.a[:,index]), @view container[:,index])
end
return container
end

function BayesBase.rand!(rng::AbstractRNG, dist::TensorDirichlet{A}, container::AbstractArray{A,N}) where {T <: Real, N, A <: AbstractArray{T,N}}
for i in container
rand!(rng, dist, @view container[i])
end
return container
end

function BayesBase.logpdf(dist::TensorDirichlet, x::AbstractArray{T,N}) where {T <: Real, N}
out = 0
wouterwln marked this conversation as resolved.
for i in CartesianIndices(extract_collection(dist.a))
out =+ logpdf(Dirichlet(dist.a[:,i]), @view x[:,i])
> Member: Did you mean to write `+=`?
>
> Contributor Author: Yes, indeed.
end
return out
end
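With independence across slices, the joint log-density is the sum of the per-slice Dirichlet log-densities (the `+=` accumulation discussed in the review above). A SciPy sketch with hypothetical inputs:

```python
import numpy as np
from scipy.stats import dirichlet

def tensor_dirichlet_logpdf(alpha, x):
    """Sum of slice-wise Dirichlet log-densities; slices are independent."""
    total = 0.0
    # iterate over trailing indices, mirroring CartesianIndices(Base.tail(size(a)))
    for idx in np.ndindex(alpha.shape[1:]):
        total += dirichlet.logpdf(x[(slice(None),) + idx], alpha[(slice(None),) + idx])
    return total

alpha = np.array([[2.0, 3.0],
                  [1.0, 1.0]])  # columns: Dirichlet([2, 1]) and Dirichlet([3, 1])
x = np.array([[0.3, 0.6],
              [0.7, 0.4]])      # each column sums to 1
lp = tensor_dirichlet_logpdf(alpha, x)
```

For these values the column densities are 0.3/B(2,1) = 0.6 and 0.36/B(3,1) = 1.08, so `lp = log(0.648)`.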

BayesBase.pdf(dist::TensorDirichlet, x::Array{T,N}) where {T <: Real,N} = exp(logpdf(dist, x))

BayesBase.default_prod_rule(::Type{<:TensorDirichlet}, ::Type{<:TensorDirichlet}) = PreserveTypeProd(Distribution)

function BayesBase.prod(::PreserveTypeProd{Distribution}, left::TensorDirichlet, right::TensorDirichlet)
paramL = extract_collection(left)
paramR = extract_collection(right)
Ones = ones(size(left.a))
Ones = extract_collection(TensorDirichlet(Ones))
param = copy(Ones)
for i in eachindex(paramL)
param[i] .= paramL[i] .+ paramR[i] .- Ones[i]
end
out = similar(left.a)
for i in CartesianIndices(param)
out[:,i] = param[i]
end
return TensorDirichlet(out)
end
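The prod rule relies on the Dirichlet family being closed under pointwise multiplication of densities: `x^(aL−1) · x^(aR−1) = x^((aL+aR−1)−1)`, so the (unnormalized) product is Dirichlet with parameter `aL + aR − 1`, applied slice-wise. A NumPy sketch with made-up parameters:

```python
import numpy as np
from scipy.stats import dirichlet

aL = np.array([[2.0, 3.0], [4.0, 1.5]])
aR = np.array([[1.5, 2.0], [2.0, 3.5]])
a_prod = aL + aR - 1.0  # slice-wise product parameter

# sanity check on one slice: logpdf_L + logpdf_R differs from logpdf_prod
# only by an x-independent normalization constant
x1, x2 = np.array([0.25, 0.75]), np.array([0.6, 0.4])
c1 = dirichlet.logpdf(x1, aL[:, 0]) + dirichlet.logpdf(x1, aR[:, 0]) - dirichlet.logpdf(x1, a_prod[:, 0])
c2 = dirichlet.logpdf(x2, aL[:, 0]) + dirichlet.logpdf(x2, aR[:, 0]) - dirichlet.logpdf(x2, a_prod[:, 0])
```

`c1 == c2` for any two points on the simplex, confirming the product rule up to normalization.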

function BayesBase.insupport(ef::ExponentialFamilyDistribution{TensorDirichlet}, x)
l = size(getnaturalparameters(ef))
values = [x[:,i] for i in CartesianIndices(Base.tail(size(x)))]
## Each element of the array should be a categorical distribution (a vector of positive values that sums to 1),
> Member: This isn't a comment, it's documentation — make it a docstring, something like this:
>
>     """
>         BayesBase.insupport(ef::ExponentialFamilyDistribution{TensorDirichlet}, x)
>
>     Check if the input `x` is within the support of a TensorDirichlet distribution.
>
>     Requirements:
>     - Each column of `x` must represent a valid categorical distribution (sum to 1, all values ≥ 0)
>     - Dimensions must match the natural parameters of the distribution
>     """

## and all categorical distributions should have the same size as the corresponding Dirichlet prior (not checked).
return l == size(x) && all(x ->sum(x) ≈ 1, values) && all(!any(x-> x < 0 ), values)
end
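The support check above boils down to: every first-axis slice of `x` must be a probability vector (nonnegative, summing to 1), and the shape must match the parameters. A hedged Python sketch of that predicate (the `insupport` name and tolerance are illustrative, not the package API):

```python
import numpy as np

def insupport(x, atol=1e-8):
    # flatten all trailing dimensions; each column must be a probability vector
    cols = x.reshape(x.shape[0], -1)
    return bool(np.all(cols >= 0) and np.allclose(cols.sum(axis=0), 1.0, atol=atol))

valid   = np.array([[0.3, 0.6], [0.7, 0.4]])   # both columns sum to 1, nonnegative
invalid = np.array([[1.2, 0.6], [-0.2, 0.4]])  # sums to 1 but has a negative entry
```

Note the second case: a floating-point sum of 1 is not enough on its own; nonnegativity must be checked separately, as the Julia code does with `!any(x -> x < 0)`.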

# Natural parametrization

isproper(::NaturalParametersSpace, ::Type{TensorDirichlet}, η, conditioner) =
isnothing(conditioner) && length(η) > 1 && all( map(x->isproper(NaturalParametersSpace(),Dirichlet,x,), unpack_parameters(TensorDirichlet, η)))
isproper(::MeanParametersSpace, ::Type{TensorDirichlet}, θ, conditioner) =
isnothing(conditioner) && length(θ) > 1 && all( map(x->isproper(MeanParametersSpace(),Dirichlet,x,),unpack_parameters(TensorDirichlet, θ)))
isproper(p, ::Type{TensorDirichlet}, η, conditioner) =
isnothing(conditioner) && all(x->isproper(p,Type{Dirichlet},x),unpack_parameters(TensorDirichlet, η))


function (::MeanToNatural{TensorDirichlet})(tuple_of_θ::Tuple{Any})
(α,) = tuple_of_θ
out = copy(α)
for i in CartesianIndices(Base.tail(size(α)))
out[:,i] = α[:,i] - ones(length(α[:,i]))
end
return out
end

function (::NaturalToMean{TensorDirichlet})(tuple_of_η::Tuple{Any})
(α,) = tuple_of_η
out = copy(α)
for i in CartesianIndices(Base.tail(size(α)))
out[:,i] = α[:,i] + ones(length(α[:,i]))
end
return out
end
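For a Dirichlet, the natural parameter is `η = α − 1`, so the two conversions above are elementwise shifts applied slice by slice (hence across the whole tensor). A trivial NumPy sketch with a made-up tensor:

```python
import numpy as np

alpha = np.array([[[2.0, 3.0], [4.0, 1.5]],
                  [[1.5, 2.0], [2.0, 3.5]]])  # mean-space parameters

eta = alpha - 1.0         # MeanToNatural: subtract ones slice-wise
alpha_back = eta + 1.0    # NaturalToMean: add them back
```

The round trip is exact, which is the invariant the two functions must satisfy.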


getlogpartition(::NaturalParametersSpace, ::Type{TensorDirichlet}) =
(η) -> begin
return mapreduce(x->getlogpartition(NaturalParametersSpace(),Dirichlet)(x),+,η)
end

getgradlogpartition(::NaturalParametersSpace, ::Type{TensorDirichlet}) =
(η) -> begin
return map(d -> getgradlogpartition(NaturalParametersSpace(), Dirichlet)(d), η)
end

getfisherinformation(::NaturalParametersSpace, ::Type{TensorDirichlet}) =
(η) -> begin
return mapreduce(d -> getfisherinformation(NaturalParametersSpace(), Dirichlet)(d),+, η)
end

# Mean parametrization

getlogpartition(::MeanParametersSpace, ::Type{TensorDirichlet}) =
(η) -> begin
return mapreduce(x->getlogpartition(MeanParametersSpace(),Dirichlet)(x),+,η)
end

getgradlogpartition(::MeanParametersSpace, ::Type{TensorDirichlet}) =
(η) -> begin
return map(d -> getgradlogpartition(MeanParametersSpace(), Dirichlet)(d), η)
end

getfisherinformation(::MeanParametersSpace, ::Type{TensorDirichlet}) =
(η) -> begin
return mapreduce(d -> getfisherinformation(MeanParametersSpace(), Dirichlet)(d),+, η)
end
