Skip to content

Commit

Permalink
Merge pull request #19 from ReactiveBayes/pointmass
Browse files Browse the repository at this point in the history
Generalize pointmasses for abstractarray
  • Loading branch information
wouterwln authored Sep 17, 2024
2 parents 66019bf + f4a9ad7 commit 19f8147
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 59 deletions.
93 changes: 34 additions & 59 deletions src/densities/pointmass.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ getpointmass(point::Union{Real,AbstractArray}) = point
BayesBase.variate_form(::Type{PointMass{T}}) where {T<:Real} = Univariate
BayesBase.variate_form(::Type{PointMass{V}}) where {T,V<:AbstractVector{T}} = Multivariate
BayesBase.variate_form(::Type{PointMass{M}}) where {T,M<:AbstractMatrix{T}} = Matrixvariate
function BayesBase.variate_form(::Type{PointMass{M}}) where {T,N,M<:AbstractArray{T,N}}
return ArrayLikeVariate{N}
end
BayesBase.variate_form(::Type{PointMass{U}}) where {T,U<:UniformScaling{T}} = Matrixvariate

function BayesBase.mean(fn::F, distribution::PointMass) where {F<:Function}
Expand Down Expand Up @@ -72,102 +75,74 @@ Base.precision(::PointMass{T}) where {T<:Real} = convert(T, Inf)
Base.ndims(::PointMass{T}) where {T<:Real} = 1
Base.eltype(::PointMass{T}) where {T<:Real} = T

# AbstractVector-based multivariate point mass
# AbstractArray-based multivariate point mass

function BayesBase.insupport(
distribution::PointMass{V}, x::AbstractVector
) where {T<:Real,V<:AbstractVector{T}}
distribution::PointMass{V}, x::AbstractArray{T,N}
) where {T<:Real,N,V<:AbstractArray{T,N}}
return x == getpointmass(distribution)
end

function BayesBase.pdf(
distribution::PointMass{V}, x::AbstractVector
) where {T<:Real,V<:AbstractVector{T}}
distribution::PointMass{V}, x::AbstractArray{T,N}
) where {T<:Real,N,V<:AbstractArray{T,N}}
return insupport(distribution, x) ? one(T) : zero(T)
end

function BayesBase.logpdf(
distribution::PointMass{V}, x::AbstractVector
) where {T<:Real,V<:AbstractVector{T}}
distribution::PointMass{V}, x::AbstractArray{T,N}
) where {T<:Real,N,V<:AbstractArray{T,N}}
return insupport(distribution, x) ? zero(T) : convert(T, -Inf)
end

function BayesBase.mean(distribution::PointMass{V}) where {T<:Real,V<:AbstractVector{T}}
function BayesBase.mean(distribution::PointMass{V}) where {T<:Real,V<:AbstractArray{T}}
return getpointmass(distribution)
end
function BayesBase.mode(distribution::PointMass{V}) where {T<:Real,V<:AbstractVector{T}}
function BayesBase.mode(distribution::PointMass{V}) where {T<:Real,V<:AbstractArray{T}}
return mean(distribution)
end
function BayesBase.var(distribution::PointMass{V}) where {T<:Real,V<:AbstractVector{T}}
return zeros(T, (ndims(distribution),))
function BayesBase.var(distribution::PointMass{V}) where {T<:Real,V<:AbstractArray{T}}
return zeros(T, ndims(distribution))
end
function BayesBase.std(distribution::PointMass{V}) where {T<:Real,V<:AbstractVector{T}}
return zeros(T, (ndims(distribution),))

function BayesBase.std(distribution::PointMass{V}) where {T<:Real,V<:AbstractArray{T}}
return zeros(T, ndims(distribution))
end
function BayesBase.cov(distribution::PointMass{V}) where {T<:Real,V<:AbstractVector{T}}

# For vectors, covariances and probvec are defined
function BayesBase.cov(distribution::PointMass{V}) where {T<:Real,V<:AbstractArray{T,1}}
return zeros(T, (ndims(distribution), ndims(distribution)))
end

function BayesBase.probvec(distribution::PointMass{V}) where {T<:Real,V<:AbstractVector{T}}
function BayesBase.probvec(distribution::PointMass{V}) where {T<:Real,V<:AbstractArray{T,1}}
return mean(distribution)
end

function Base.precision(distribution::PointMass{V}) where {T<:Real,V<:AbstractVector{T}}
return one(T) ./ cov(distribution)
function BayesBase.cov(distribution::PointMass{M}) where {T<:Real,N,M<:AbstractArray{T,N}}
return error("cov(::PointMass{ <: $M }) is not defined")
end

function Base.ndims(distribution::PointMass{V}) where {T<:Real,V<:AbstractVector{T}}
return length(mean(distribution))
function BayesBase.probvec(
distribution::PointMass{M}
) where {T<:Real,N,M<:AbstractArray{T,N}}
return error("probvec(::PointMass{ <: $M }) is not defined")
end

Base.eltype(::PointMass{V}) where {T<:Real,V<:AbstractVector{T}} = T

# AbstractMatrix-based matrixvariate point mass

function BayesBase.insupport(
distribution::PointMass{M}, x::AbstractMatrix
) where {T<:Real,M<:AbstractMatrix{T}}
return x == getpointmass(distribution)
end
function BayesBase.pdf(
distribution::PointMass{M}, x::AbstractMatrix
) where {T<:Real,M<:AbstractMatrix{T}}
return insupport(distribution, x) ? one(T) : zero(T)
end
function BayesBase.logpdf(
distribution::PointMass{M}, x::AbstractMatrix
) where {T<:Real,M<:AbstractMatrix{T}}
return insupport(distribution, x) ? zero(T) : convert(T, -Inf)
function Base.precision(distribution::PointMass{V}) where {T<:Real,V<:AbstractArray{T}}
return one(T) ./ cov(distribution)
end

function BayesBase.mean(distribution::PointMass{M}) where {T<:Real,M<:AbstractMatrix{T}}
return getpointmass(distribution)
end
function BayesBase.mode(distribution::PointMass{M}) where {T<:Real,M<:AbstractMatrix{T}}
return mean(distribution)
end
function BayesBase.var(distribution::PointMass{M}) where {T<:Real,M<:AbstractMatrix{T}}
return zeros(T, ndims(distribution))
end
function BayesBase.std(distribution::PointMass{M}) where {T<:Real,M<:AbstractMatrix{T}}
return zeros(T, ndims(distribution))
end
function BayesBase.cov(distribution::PointMass{M}) where {T<:Real,M<:AbstractMatrix{T}}
return error("cov(::PointMass{ <: AbstractMatrix }) is not defined")
end
# We need this function for backwards compatibility

function BayesBase.probvec(distribution::PointMass{M}) where {T<:Real,M<:AbstractMatrix{T}}
return error("probvec(::PointMass{ <: AbstractMatrix }) is not defined")
function Base.ndims(distribution::PointMass{V}) where {T<:Real,V<:AbstractVector{T}}
return length(mean(distribution))
end

function Base.precision(distribution::PointMass{M}) where {T<:Real,M<:AbstractMatrix{T}}
return one(T) ./ cov(distribution)
end
function Base.ndims(distribution::PointMass{M}) where {T<:Real,M<:AbstractMatrix{T}}
function Base.ndims(distribution::PointMass{M}) where {T<:Real,N,M<:AbstractArray{T,N}}
return size(mean(distribution))
end

Base.eltype(::PointMass{M}) where {T<:Real,M<:AbstractMatrix{T}} = T
Base.eltype(::PointMass{V}) where {T<:Real,N,V<:AbstractArray{T,N}} = T

# UniformScaling-based matrixvariate point mass

Expand Down
64 changes: 64 additions & 0 deletions test/densities/pointmass_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ end
@testitem "Vector-based PointMass" begin
using SpecialFunctions: loggamma
using TinyHugeNumbers
using BayesBase

for T in (Float16, Float32, Float64, BigFloat), N in (5, 10)
vector = rand(T, N)
Expand Down Expand Up @@ -172,6 +173,69 @@ end
end
end

@testitem "Tensor-based PointMass" begin
using SpecialFunctions: loggamma
using TinyHugeNumbers
using Distributions

for D in [3, 4, 5]
for T in (Float16, Float32, Float64, BigFloat), N in (5, 10)
tensor = rand(T, ntuple(_ -> N, D))
dist = PointMass(tensor)

@test variate_form(typeof(dist)) === Distributions.ArrayLikeVariate{D}
@test dist[2] === tensor[2]
@test dist[3] === tensor[3]
@test dist[ntuple(_ -> 3, D)...] === tensor[ntuple(_ -> 3, D)...]
for i in 1:D
@test size(dist, i) === size(tensor, i)
end
@test_throws BoundsError dist[N^(D + 1)]
@test_throws BoundsError dist[ntuple(_ -> N + 1, D)...]

@test insupport(dist, tensor)
@test !insupport(dist, tensor .+ tiny)
@test !insupport(dist, tensor .- tiny)

@test @inferred(T, pdf(dist, tensor)) == one(T)
@test @inferred(T, pdf(dist, tensor .+ tiny)) == zero(T)
@test @inferred(T, pdf(dist, tensor .- tiny)) == zero(T)

@test @inferred(T, logpdf(dist, tensor)) == zero(T)
@test @inferred(T, logpdf(dist, tensor .+ tiny)) == convert(T, -Inf)
@test @inferred(T, logpdf(dist, tensor .- tiny)) == convert(T, -Inf)

for i in 1:(D - 1)
@test_throws MethodError insupport(dist, ones(T, ntuple(_ -> 2, i)...))
@test_throws MethodError pdf(dist, ones(T, ntuple(_ -> 2, i)...))
@test_throws MethodError logpdf(dist, ones(T, ntuple(_ -> 2, i)...))
end

@test (@inferred entropy(dist)) == BayesBase.MinusInfinity(T)

@test @inferred(AbstractArray{D,T}, mean(dist)) == tensor
@test @inferred(AbstractArray{D,T}, mode(dist)) == tensor
@test @inferred(AbstractArray{D,T}, var(dist)) == zeros(ntuple(_ -> N, D)...)
@test @inferred(AbstractArray{D,T}, std(dist)) == zeros(ntuple(_ -> N, D)...)
@test @inferred(Tuple{Int,Int,Int}, ndims(dist)) == ntuple(_ -> N, D)
@test @inferred(Type{T}, eltype(dist)) == T

@test_throws ErrorException cov(dist)
@test_throws ErrorException precision(dist)

@test_throws ErrorException probvec(dist)
@test @inferred(
AbstractArray{D,T}, mean(Base.Broadcast.BroadcastFunction(log), dist)
) == log.(tensor)
@test @inferred(
AbstractArray{D,T}, mean(Base.Broadcast.BroadcastFunction(loggamma), dist)
) == loggamma.(tensor)

@test_throws MethodError mean(loggamma, dist)
end
end
end

@testitem "UniformScaling-based PointMass" begin
using LinearAlgebra, TinyHugeNumbers

Expand Down

0 comments on commit 19f8147

Please sign in to comment.