Skip to content

Commit

Permalink
feat: add samplelist
Browse files Browse the repository at this point in the history
  • Loading branch information
bvdmitri committed Oct 11, 2023
1 parent ee9c159 commit f3ee026
Show file tree
Hide file tree
Showing 8 changed files with 1,439 additions and 19 deletions.
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ version = "1.0.0"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Expand All @@ -18,14 +20,16 @@ TinyHugeNumbers = "783c9a47-75a3-44ac-a16b-f1ab7b3acf04"
[compat]
Distributions = "0.25"
DomainSets = "0.7"
LinearAlgebra = "1.9"
Random = "1.9"
SpecialFunctions = "2.3"
Statistics = "1.9"
StatsAPI = "1.7"
StatsBase = "0.34"
StaticArrays = "1.6"
LoopVectorization = "0.12"
StatsFuns = "1.3"
TinyHugeNumbers = "1.0"
LinearAlgebra = "1.9"
julia = "1.9"

[extras]
Expand Down
1 change: 1 addition & 0 deletions src/BayesBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ include("prod.jl")

include("densities/pointmass.jl")
include("densities/function.jl")
include("densities/samplelist.jl")
include("densities/mixture.jl")
include("densities/factorizedjoint.jl")

Expand Down
3 changes: 0 additions & 3 deletions src/densities/function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ getdomain(dist::AbstractContinuousGenericLogPdf) = dist.domain
getlogpdf(dist::AbstractContinuousGenericLogPdf) = dist.logpdf

BayesBase.value_support(::Type{<:AbstractContinuousGenericLogPdf}) = Continuous
BayesBase.value_support(::AbstractContinuousGenericLogPdf) = Continuous

# We throw an error on purpose, since we do not want to use `AbstractContinuousGenericLogPdf` much without approximations
# We want to encourage a user to use approximate generic log-pdfs as much as possible instead
Expand Down Expand Up @@ -92,7 +91,6 @@ function ContinuousUnivariateLogPdf(f::Function)
end

BayesBase.variate_form(::Type{<:ContinuousUnivariateLogPdf}) = Univariate
BayesBase.variate_form(::ContinuousUnivariateLogPdf) = Univariate

function BayesBase.promote_variate_type(
::Type{Univariate}, ::Type{AbstractContinuousGenericLogPdf}
Expand Down Expand Up @@ -172,7 +170,6 @@ struct ContinuousMultivariateLogPdf{D<:DomainSets.Domain,F} <:
end

BayesBase.variate_form(::Type{<:ContinuousMultivariateLogPdf}) = Multivariate
BayesBase.variate_form(::ContinuousMultivariateLogPdf) = Multivariate

function BayesBase.promote_variate_type(
::Type{Multivariate}, ::Type{AbstractContinuousGenericLogPdf}
Expand Down
8 changes: 4 additions & 4 deletions src/densities/pointmass.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ end
getpointmass(distribution::PointMass) = distribution.point
getpointmass(point::Union{Real,AbstractArray}) = point

BayesBase.variate_form(::PointMass{T}) where {T<:Real} = Univariate
BayesBase.variate_form(::PointMass{V}) where {T,V<:AbstractVector{T}} = Multivariate
BayesBase.variate_form(::PointMass{M}) where {T,M<:AbstractMatrix{T}} = Matrixvariate
BayesBase.variate_form(::PointMass{U}) where {T,U<:UniformScaling{T}} = Matrixvariate
BayesBase.variate_form(::Type{PointMass{T}}) where {T<:Real} = Univariate
BayesBase.variate_form(::Type{PointMass{V}}) where {T,V<:AbstractVector{T}} = Multivariate
BayesBase.variate_form(::Type{PointMass{M}}) where {T,M<:AbstractMatrix{T}} = Matrixvariate
BayesBase.variate_form(::Type{PointMass{U}}) where {T,U<:UniformScaling{T}} = Matrixvariate

function BayesBase.mean(fn::F, distribution::PointMass) where {F<:Function}
return fn(mean(distribution))
Expand Down
Loading

0 comments on commit f3ee026

Please sign in to comment.