Dirichlet tensor #219
base: main
Conversation
const ContinuousTensorDistribution = Distribution{ArrayLikeVariate, Continuous}

"""
    TensorDirichlet{T <: Real, A <: AbstractArrray{T,3}} <: ContinuousTensorDistribution
Do not document the `{ }` part, since it's an internal detail, and it's also wrongly documented.
function BayesBase.logpdf(dist::TensorDirichlet, x::AbstractArray{T,N}) where {T <: Real, N}
    out = 0
    for i in CartesianIndices(extract_collection(dist.a))
        out =+ logpdf(Dirichlet(dist.a[:,i]), @view x[:,i])
Did you mean to write `+=`?
Yes, indeed.
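For context, in Julia `out =+ x` parses as `out = (+x)`, so each iteration overwrites the accumulator instead of adding to it. A minimal standalone illustration (the helper names are hypothetical, not from the PR):

```julia
# `=+` parses as assignment of unary plus: the accumulator is overwritten
# on every iteration instead of being accumulated.
function sum_wrong(xs)
    out = 0
    for x in xs
        out =+ x   # equivalent to out = +x
    end
    return out
end

# `+=` accumulates as intended.
function sum_right(xs)
    out = 0
    for x in xs
        out += x
    end
    return out
end

sum_wrong((1, 2, 3))  # returns 3, the last element only
sum_right((1, 2, 3))  # returns 6
```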
varDiri[:,2,1] = var(Dirichlet(c))
varDiri[:,2,2] = var(Dirichlet(d))

@test var(TensorDirichlet(tensorDiri)) == varDiri
That's quite a strange interface for `var`.
@test cov(TensorDirichlet(tensorDiri)) == covTensorDiri
end

@test_broken "TensorDirichlet: ExponentialFamilyDistribution" begin
Do not write tests outside of `@testitem`.
@test @inferred(prod(PreserveTypeProd(Distribution), d2, d3)) ≈ TensorDirichlet([1.2000000000000002 4.0 2.0 ; 3.3 5.0 1.1])
end

@test_broken "TensorDirichlet: prod with ExponentialFamilyDistribution" begin
Do not write tests outside of `@testitem`.
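For reference, assuming the test suite runs on TestItems.jl/ReTestItems.jl, a test wrapped the requested way would be sketched as (the body is a placeholder, not the PR's actual test):

```julia
using TestItems

# Each @testitem runs in its own module scope, so every dependency
# must be loaded inside the block.
@testitem "TensorDirichlet: prod with ExponentialFamilyDistribution" begin
    using ExponentialFamily

    # ... the body of the @test_broken block above goes here ...
end
```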
I'm quite skeptical about this PR. First of all, is this distribution from the exponential family? If not, why should it belong to this package?
Second, why do all these shenanigans with `a[:, m, n]` when you could simply use an array of `Dirichlet` distributions?
E.g.
a = [
    Dirichlet([ 0.1, 0.9 ]) Dirichlet([ 0.1, 0.9 ]);
    Dirichlet([ 0.1, 0.9 ]) Dirichlet([ 0.1, 0.9 ])
]
and then later on simply
a[1, 2] # returns a Dirichlet
I'm also not sure the functionality is really correct; the `logpdf` can't be correct, right? Why didn't the tests pick it up?
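To make the suggested alternative concrete, here is a sketch (using Distributions.jl; the layout and variable names are illustrative, not from the PR) of how an array of `Dirichlet` factors composes a joint `logpdf`:

```julia
using Distributions

# A 2x2 grid of independent Dirichlet factors stored directly as objects.
a = [Dirichlet([0.1, 0.9]) Dirichlet([0.5, 0.5]);
     Dirichlet([2.0, 3.0]) Dirichlet([1.0, 1.0])]

a[1, 2]  # plain indexing returns a Dirichlet

# x[:, m, n] is the point evaluated against the factor a[m, n].
x = fill(0.5, 2, 2, 2)

# Independence makes the joint log-density a plain sum over factors.
total = sum(logpdf(a[i], x[:, i]) for i in CartesianIndices(a))
```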
@bvdmitri This is a generalization of
Co-authored-by: Bagaev Dmitry <bvdmitri@gmail.com>
Thank you for the review @bvdmitri. I discussed this with @Nimrais, and I think it is indeed a member of the exponential family.
I agree with working directly with a Float array, but only if you implement it correctly. I doubt that the current implementation is faster or more efficient, because it creates a lot of slices and allocates new memory on every access, which is perhaps even slower than using arrays of Dirichlet distributions. You sometimes use
@bvdmitri I'll work on performance; for now we just need a correct generalization of
Co-authored-by: Bagaev Dmitry <bvdmitri@gmail.com>
…ponentialFamily.jl into DirichletTensor
If the performance is not a goal at this point, why not make an array of Dirichlet distributions? It's far easier to understand, at least.
function BayesBase.insupport(ef::ExponentialFamilyDistribution{TensorDirichlet}, x)
    l = size(getnaturalparameters(ef))
    values = [x[:,i] for i in CartesianIndices(Base.tail(size(x)))]
    ## Each element of the array should be a categorical distribution (a vector of positive values that sum to 1)
It's not a comment, this is documentation; make it a docstring, something like this:
"""
BayesBase.insupport(ef::ExponentialFamilyDistribution{TensorDirichlet}, x)
Check if the input `x` is within the support of a Tensor Dirichlet distribution.
Requirements:
- Each column of x must represent a valid categorical distribution (sum to 1, all values ≥ 0)
- Dimensions must match the natural parameters of the distribution
"""
Base.size(dist::TensorDirichlet) = size(dist.a)
Base.eltype(::TensorDirichlet{T}) where {T} = T

function BayesBase.vague(::Type{<:TensorDirichlet}, dims::Int)
In one of the `vague` methods (the one with a tuple) you specify the element type of `ones`, while in the other you don't. Is there any reason why you aren't writing it like this?
function BayesBase.vague(::Type{<:TensorDirichlet}, dims::Int)
    return TensorDirichlet(ones(Float64, dims, dims))
end
No specific reason, I missed this difference. `Float64` seems to be the default type, so it should be the same. I will add the `Float64` argument for clarity. Thanks.
    a::A
end

extract_collection(dist::TensorDirichlet) = [dist.a[:,i] for i in CartesianIndices(Base.tail(size(dist.a)))]
extract_collection(dist::TensorDirichlet) = (view(dist.a, :, i) for i in CartesianIndices(Base.tail(size(dist.a))))
?
Yes, with a `@view` it will be much better. I will take your suggestion.
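To illustrate the difference (a standalone sketch, not the PR's code): slicing with `A[:, i]` allocates a fresh copy on every access, while `view`/`@view` aliases the parent array:

```julia
A = reshape(collect(1.0:24.0), 4, 3, 2)
idx = CartesianIndices(Base.tail(size(A)))

slices = [A[:, i] for i in idx]        # each slice is a freshly allocated copy
views  = [view(A, :, i) for i in idx]  # each view aliases the parent array

views[1][1] = -1.0   # writes through to the parent
A[1, 1, 1]           # now -1.0
slices[1][1]         # still 1.0: the earlier copy is unaffected
```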
@Nimrais Because in the end we do need a performant implementation of this, which would require the data to be stored the way Raphael coded it. It is a lot easier to implement performant versions of methods if the underlying data structure does not change anymore.
Sure, but then it would be nice to see proof that it works faster (@Raphael-Tresor this is more or less a comment for you for next time). For example, a test that shows that the current way of implementing things is better than the naïve one. Because if, at some point in time, it starts to become slower, what is the point? For example, something like this, but written carefully:
@Nimrais you mean something like this?
Yeah, but we have more than only the In my example, because the
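A benchmark along the suggested lines could be sketched with BenchmarkTools.jl (the two implementations below are illustrative stand-ins for the naïve and tensor representations, not the PR's code):

```julia
using BenchmarkTools, Distributions

alpha = rand(3, 4, 4) .+ 1.0               # parameter tensor: each column is a Dirichlet parameter
x = rand(3, 4, 4); x ./= sum(x; dims = 1)  # normalize so every column lies on the simplex
idx = CartesianIndices(Base.tail(size(alpha)))

# Naive representation: a plain array of Dirichlet objects.
naive = [Dirichlet(alpha[:, i]) for i in idx]
logpdf_naive(ds, x, idx) = sum(logpdf(ds[i], @view x[:, i]) for i in idx)

# Tensor representation: build each factor on the fly from a slice.
logpdf_tensor(a, x, idx) = sum(logpdf(Dirichlet(a[:, i]), @view x[:, i]) for i in idx)

@btime logpdf_naive($naive, $x, $idx)
@btime logpdf_tensor($alpha, $x, $idx)
```

Both versions must agree on the value, so the benchmark doubles as a correctness check.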
Dirichlet tensor performance
Reimplement entropy for TensorDirichlet
…ponentialFamily.jl into DirichletTensor
I am currently working on the PR. It is not ready for review, so I will convert it to a draft.
Issue
The branch is not ExponentialFamily-compatible yet: the pack/unpack mechanism does not work.