Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dirichlet tensor #219

Draft
wants to merge 48 commits into
base: main
Choose a base branch
from
Draft

Dirichlet tensor #219

wants to merge 48 commits into from

Conversation

Raphael-Tresor
Copy link
Contributor

@Raphael-Tresor Raphael-Tresor commented Dec 24, 2024

Issue

The branch is not ExponentialFamily compatible yet. The pack unpack mechanism does not work.

const ContinuousTensorDistribution = Distribution{ ArrayLikeVariate, Continuous}

"""
TensorDirichlet{T <: Real, A <: AbstractArrray{T,3}} <: ContinuousTensorDistribution
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do not document the { } thing since its an internal detail and its also wrongly documented

function BayesBase.logpdf(dist::TensorDirichlet, x::AbstractArray{T,N}) where {T <: Real, N}
out = 0
for i in CartesianIndices(extract_collection(dist.a))
out =+ logpdf(Dirichlet(dist.a[:,i]), @view x[:,i])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you mean to write +=?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes indeed

varDiri[:,2,1] = var(Dirichlet(c))
varDiri[:,2,2] = var(Dirichlet(d))

@test var(TensorDirichlet(tensorDiri)) == varDiri
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thats quite a strange interface for var

@test cov(TensorDirichlet(tensorDiri)) == covTensorDiri
end

@test_broken "TensorDirichlet: ExponentialFamilyDistribution" begin
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do not write tests outside of @testitem

@test @inferred(prod(PreserveTypeProd(Distribution), d2, d3)) ≈ TensorDirichlet([1.2000000000000002 4.0 2.0 ; 3.3 5.0 1.1])
end

@test_broken "TensorDirichlet: prod with ExponentialFamilyDistribution" begin
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do not write tests outside of @testitem

Copy link
Member

@bvdmitri bvdmitri left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm quite skeptical about this PR. First of all, is this distribution is from exponential family? If not, why should it belong to this package?

Second, why to do all these shenanigans with a[:, m, n] when you could simply to an array of Dirichlet distributions?

E.g.

a = [
    Dirichlet([ 0.1, 0.9 ]), Dirichlet([ 0.1, 0.9 ]),
    Dirichlet([ 0.1, 0.9 ]), Dirichlet([ 0.1, 0.9 ])
]

and then later on simply

a[1, 2] # returns Dirichlet 

I'm also not sure if the functionality is really correct, the logpdf can't be correct right? Why tests didn't pick it up?

@wouterwln
Copy link
Member

@bvdmitri This is a generalization of MatrixDirichlet (which is not really a Matrix-Dirichlet distribution but rather a collection of Dirichlets stored as a matrix). This distribution is the conjugate prior for the parameters of the general Transition node, which we need to do POMDPs. It is just a generalization of MatrixDirichlet, so I would say it is in the Exponential Family.

Co-authored-by: Bagaev Dmitry <bvdmitri@gmail.com>
@Raphael-Tresor
Copy link
Contributor Author

Raphael-Tresor commented Jan 7, 2025

I'm quite skeptical about this PR. First of all, is this distribution is from exponential family? If not, why should it belong to this package?

Second, why to do all these shenanigans with a[:, m, n] when you could simply to an array of Dirichlet distributions?

E.g.

a = [
    Dirichlet([ 0.1, 0.9 ]), Dirichlet([ 0.1, 0.9 ]),
    Dirichlet([ 0.1, 0.9 ]), Dirichlet([ 0.1, 0.9 ])
]

and then later on simply

a[1, 2] # returns Dirichlet 

I'm also not sure if the functionality is really correct, the logpdf can't be correct right? Why tests didn't pick it up?

Thank yo for the review @bvdmitri.

I discussed this with @Nimrais, and I think it is indeed a member of the exponential family.
About the implementation: creating an array of Dirichlets was my first idea, but after a discussion with @wouterwln I changed my mind. It seems that an array of Dirichlets might be slow and working directly with Float-array should be more efficient. I will think about it again.

@bvdmitri
Copy link
Member

bvdmitri commented Jan 7, 2025

@Raphael-Tresor

It seems that an array of Dirichlets might be slow and working directly with Float-array should be more efficient. I will think about it again.

I agree that working directly with Float-array, but only if you implement it correctly. I doubt that the current implementation is faster or more efficient because it creates a lot of slices and allocates new memory on every access, which is perhaps even slower than using arrays of Dirichlet distributions. You sometimes use @view but not everywhere. I don't have any benchmarks, it just more like a feeling. This being said, I don't really have a strong opinion on how to implement it and leave the choice to you.

@wouterwln
Copy link
Member

wouterwln commented Jan 7, 2025

@bvdmitri I'll work on performance, for now we just need a correct generalization of MatrixDirichlet for higher order tensors, such that we can parameterize Transition and TransitionMixture nodes.

@Nimrais
Copy link
Member

Nimrais commented Jan 7, 2025

@bvdmitri I'll work on performance, for now we just need a correct generalization of MatrixDirichlet for higher order tensors, such that we can parameterize Transition and TransitionMixture nodes.

If the performance is not a goal at this point, why not make an array of Dirichlet distributions? It's far easier to understand at least.

function BayesBase.insupport(ef::ExponentialFamilyDistribution{TensorDirichlet}, x)
l = size(getnaturalparameters(ef))
values = [x[:,i] for i in CartesianIndices(Base.tail(size(x)))]
## The element of the array should be the a categorical distribution (an vector of postive value that sum to 1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not a comment, this is the documentation make it a docstring, smt like this

"""
BayesBase.insupport(ef::ExponentialFamilyDistribution{TensorDirichlet}, x)

Check if the input `x` is within the support of a Tensor Dirichlet distribution.

Requirements:
- Each column of x must represent a valid categorical distribution (sum to 1, all values ≥ 0)
- Dimensions must match the natural parameters of the distribution

"""

Base.size(dist::TensorDirichlet) = size(dist.a)
Base.eltype(::TensorDirichlet{T}) where {T} = T

function BayesBase.vague(::Type{<:TensorDirichlet}, dims::Int)
Copy link
Member

@Nimrais Nimrais Jan 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In one of the vague methods (with a tuple), you specify the type of ones, while in another, you don't. Is there any reason why you aren't writing it like this?

function BayesBase.vague(::Type{<:TensorDirichlet}, dims::Int) 
    return TensorDirichlet(ones(Float64, dims, dims))
end

Copy link
Contributor Author

@Raphael-Tresor Raphael-Tresor Jan 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No specific reason, I missed this difference.
Float64 seems to be the default type, so it should be the same.
I will add the Float64 argument for clarity.
Thanks

a::A
end

extract_collection(dist::TensorDirichlet) = [dist.a[:,i] for i in CartesianIndices(Base.tail(size(dist.a)))]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

extract_collection(dist::TensorDirichlet) = (view(dist.a, :, i) for i in CartesianIndices(Base.tail(size(dist.a))))?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes with a '@view' it will be way better. I will take your suggestion.

@wouterwln
Copy link
Member

@Nimrais Because in the end we do need a performant implementation of this, which would require the data to be stored the way Raphael coded it. It is a lot easier to implement performant versions of methods if the underlying data structure does not change anymore.

@Nimrais
Copy link
Member

Nimrais commented Jan 7, 2025

Because in the end we do need a performant implementation of this, which would require the data to be stored the way Raphael coded it. It is a lot easier to implement performant versions of methods if the underlying data structure does not change anymore.

Sure, but then it would be nice to see the proof that it works faster. (@Raphael-Tresor it's more or less comment for you for the next time)

For example, a test that shows that the current way of implementing things is better than the naïve. Because if, at some point in time, it is starting to become slower, what is the point?

For example, smt like this, but written carefully

using ExponentialFamily
using BenchmarkTools
using Random
using Distributions
using Test

struct ArrayDirichlet{T}
    distributions::Array{Dirichlet{T}, 2}
end

function array_dirichlet(params::Array{T, 3}) where T
    m, n, k = size(params)
    dists = Array{Dirichlet{T}, 2}(undef, n, k)
    for i in 1:n, j in 1:k
        dists[i,j] = Dirichlet(params[:,i,j])
    end
    return ArrayDirichlet(dists)
end

struct TensorDirichlet{T}
    a::Array{T, 3}
end

dim = 10
grid_size = 20 
params = rand(dim, grid_size, grid_size) .+ 0.1

tensor_impl = TensorDirichlet(params)
array_impl = array_dirichlet(params)

test_points = [rand(dim, grid_size, grid_size) for _ in 1:100]
for x in test_points
    x ./= sum(x, dims=1)
end

tensor_bench = @benchmark for x in $test_points
    for i in 1:grid_size, j in 1:grid_size
        logpdf(Dirichlet(view($tensor_impl.a,:,i,j)), view(x,:,i,j))
    end
end

array_bench = @benchmark for x in $test_points
    for i in 1:grid_size, j in 1:grid_size
        logpdf($array_impl.distributions[i,j], view(x,:,i,j))
    end
end

println("Tensor Implementation:")
println("---------------------")
show(stdout, MIME("text/plain"), tensor_bench)
println("\n\nArray Implementation:")
println("--------------------")
show(stdout, MIME("text/plain"), array_bench)
@test min(array_bench.times...) > min(tensor_bench.times...)

@wouterwln
Copy link
Member

@Nimrais you mean something like this?

using ExponentialFamily
using BenchmarkTools
using Random
using Distributions
using Test
using BayesBase

struct ArrayDirichlet{T,S,P, N}
    distributions::AbstractArray{Dirichlet{T, S, P}, N}
end

function array_dirichlet(params::Array{T, N}) where {T, N}
    d, k... = size(params)
    return ArrayDirichlet(Dirichlet.(eachslice(params;dims=Tuple(2:N))))
end
function BayesBase.mean(d::ArrayDirichlet) where {T, N}
    return mean.(d.distributions)
end
dim = 10
grid_size = 20 
params = rand(dim, grid_size, grid_size) .+ 0.1

tensor_impl = TensorDirichlet(params)
array_impl = array_dirichlet(params)

tensor_bench = @benchmark mean($tensor_impl)

array_bench = @benchmark mean($array_impl)

println("Tensor Implementation:")
println("---------------------")
show(stdout, MIME("text/plain"), tensor_bench)
println("\n\nArray Implementation:")
println("--------------------")
show(stdout, MIME("text/plain"), array_bench)
@test min(array_bench.times...) > min(tensor_bench.times...)
Tensor Implementation:
---------------------
BenchmarkTools.Trial: 10000 samples with 9 evaluations.
 Range (min … max):  2.380 μs … 296.157 μs  ┊ GC (min … max):  0.00% … 97.94%
 Time  (median):     4.148 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   5.627 μs ±  19.098 μs  ┊ GC (mean ± σ):  27.80% ±  8.09%

                   ▃▆██▆▃                                      
  ▄▇▅▄▄▂▂▁▁▁▁▁▁▂▃▅▇███████▆▅▄▃▂▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
  2.38 μs         Histogram: frequency by time        7.62 μs <

 Memory estimate: 34.69 KiB, allocs estimate: 3.

Array Implementation:
--------------------
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):   9.500 μs …  1.076 ms  ┊ GC (min … max): 0.00% … 75.59%
 Time  (median):     10.500 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   12.139 μs ± 24.553 μs  ┊ GC (mean ± σ):  4.44% ±  3.29%

  ▃▇██▆▄▁                                                     ▂
  ███████▇▆██████▇▆▅▆▄▅▅▃▄▄▄▄▄▄▄▁▄▃▅▄▄▅▃▅▄▅▃▄▁▄▃▄▄▄▅▁▅▃▄▃▃▃▄▄ █
  9.5 μs       Histogram: log(frequency) by time      28.9 μs <

 Memory estimate: 59.66 KiB, allocs estimate: 403.
Test Passed

@Nimrais
Copy link
Member

Nimrais commented Jan 7, 2025

@Nimrais you mean something like this

Yeah, but we have more than only the mean method; I did an example of this test for the logpdf you did for the mean method.

In my example, because the logpdf "method" for TensorDirechlet is not optimized, it takes longer than the naïve implementation.

min(array_bench.times...) > min(tensor_bench.times...) is not passing in my example.

Test Failed at /Users/mykola/repos/biaslab/ExponentialFamily.jl/test_2.jl:54
  Expression: min(array_bench.times...) > min(tensor_bench.times...)
   Evaluated: 1.2759209e7 > 3.3939458e7

@Nimrais
Copy link
Member

Nimrais commented Jan 10, 2025

I am currently working on the PR. It is not ready for review, so I will convert it to a draft.

@Nimrais Nimrais marked this pull request as draft January 10, 2025 11:43
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants