Skip to content

Commit

Permalink
Merge pull request #61 from JuliaGaussianProcesses/tgf/rework_categor…
Browse files Browse the repository at this point in the history
…ical

Rework categorical to allow multiple variants
  • Loading branch information
theogf authored Jan 31, 2022
2 parents 40d2790 + 316d739 commit 05abf88
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 13 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "GPLikelihoods"
uuid = "6031954c-0455-49d7-b3b9-3e1c99afaf40"
authors = ["JuliaGaussianProcesses Team"]
version = "0.2.7"
version = "0.3.0"

[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Expand Down
1 change: 1 addition & 0 deletions src/GPLikelihoods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ export BernoulliLikelihood,
GammaLikelihood
export Link,
ChainLink,
BijectiveSimplexLink,
ExpLink,
LogLink,
InvLink,
Expand Down
33 changes: 26 additions & 7 deletions src/likelihoods/categorical.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,41 @@
"""
CategoricalLikelihood(l=softmax)
CategoricalLikelihood(l=BijectiveSimplexLink(softmax))
Categorical likelihood is to be used if we assume that the
uncertainity associated with the data follows a Categorical distribution.
uncertainty associated with the data follows a [Categorical distribution](https://en.wikipedia.org/wiki/Categorical_distribution).
Assuming a distribution with `n` categories:
## `n-1` inputs (bijective link)
One can work with a bijective transformation by wrapping a link (like `softmax`)
into a [`BijectiveSimplexLink`](@ref) and only needs `n-1` inputs:
```math
p(y|f_1, f_2, \\dots, f_{n-1}) = \\operatorname{Categorical}(y | l(f_1, f_2, \\dots, f_{n-1}, 0))
```
Given an `AbstractVector` ``[f_1, f_2, ..., f_{n-1}]``, returns a `Categorical` distribution,
with probabilities given by ``l(f_1, f_2, ..., f_{n-1}, 0)``.
The default constructor is a bijective link around `softmax`.
## `n` inputs (non-bijective link)
One can also pass directly the inputs without concatenating a `0`:
```math
p(y|f_1, f_2, \\dots, f_n) = \\operatorname{Categorical}(y | l(f_1, f_2, \\dots, f_n))
```
This variant is over-parametrized, as there are `n-1` independent parameters
embedded in a `n` dimensional parameter space.
For more details, see the end of the section of this [Wikipedia link](https://en.wikipedia.org/wiki/Exponential_family#Table_of_distributions)
where it corresponds to Variant 1 and 2.
"""
struct CategoricalLikelihood{Tl<:AbstractLink} <: AbstractLikelihood
invlink::Tl
end

CategoricalLikelihood(l=softmax) = CategoricalLikelihood(link(l))
CategoricalLikelihood(l=BijectiveSimplexLink(softmax)) = CategoricalLikelihood(link(l))

(l::CategoricalLikelihood)(f::AbstractVector{<:Real}) = Categorical(l.invlink(vcat(f, 0)))
function (l::CategoricalLikelihood)(f::AbstractVector{<:Real})
return Categorical(l.invlink(f))
end

function (l::CategoricalLikelihood)(fs::AbstractVector)
return Product(Categorical.(l.invlink.(vcat.(fs, 0))))
return Product(Categorical.(l.invlink.(fs)))
end
17 changes: 17 additions & 0 deletions src/links.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,23 @@ link(l::AbstractLink) = l

Base.inv(l::Link) = Link(InverseFunctions.inverse(l.f))

"""
BijectiveSimplexLink(link)
Wrapper to preprocess the inputs by adding a `0` at the end before passing it to
the link `link`.
This is a necessary step to work with simplices.
For example with the [`SoftMaxLink`](@ref), to obtain a `n`-simplex leading to
`n+1` categories for the [`CategoricalLikelihood`](@ref),
one needs to pass `n+1` latent GP.
However, by wrapping the link into a `BijectiveSimplexLink`, only `n` latent are needed.
"""
struct BijectiveSimplexLink{L} <: AbstractLink
link::L
end

(l::BijectiveSimplexLink)(f::AbstractVector{<:Real}) = l.link(vcat(f, 0))

# alias
const LogLink = Link{typeof(log)}
const ExpLink = Link{typeof(exp)}
Expand Down
14 changes: 9 additions & 5 deletions test/likelihoods/categorical.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
@testset "CategoricalLikelihood" begin
for args in ((), (softmax,), (SoftMaxLink(),))
@test CategoricalLikelihood(args...) isa CategoricalLikelihood{SoftMaxLink}
end
@test CategoricalLikelihood() isa
CategoricalLikelihood{<:GPLikelihoods.BijectiveSimplexLink}

@test CategoricalLikelihood(softmax) isa CategoricalLikelihood{SoftMaxLink}
@test CategoricalLikelihood(SoftMaxLink()) isa CategoricalLikelihood{SoftMaxLink}

lik = CategoricalLikelihood()
OUT_DIM = 4
test_interface(lik, Categorical, OUT_DIM)
lik_bijective = CategoricalLikelihood()
test_interface(lik_bijective, Categorical, OUT_DIM)
lik_nonbijective = CategoricalLikelihood(softmax)
test_interface(lik_nonbijective, Categorical, OUT_DIM)
end
5 changes: 5 additions & 0 deletions test/links.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@
l = GPLikelihoods.link(ExpLink())
@test l == ExpLink()

## SimplexBijective link
l = SoftMaxLink()
sbl = BijectiveSimplexLink(l)
@test sbl(xs) == l(vcat(xs, 0))

# Log
l = LogLink()
@test l(x) == log(x)
Expand Down

0 comments on commit 05abf88

Please sign in to comment.