From 1f5b041d242ecba97d41df84914a2fcf9aa29e83 Mon Sep 17 00:00:00 2001 From: Bagaev Dmitry Date: Tue, 10 Oct 2023 14:53:36 +0200 Subject: [PATCH] move stuff from ExpFamily.jl --- Project.toml | 27 +- README.md | 7 + docs/src/index.md | 7 + src/BayesBase.jl | 69 +++- src/densities/factorizedjoint.jl | 42 ++ src/prod.jl | 501 ++++++++++++++++++++++++ src/promotion.jl | 179 +++++++++ src/statsfuns.jl | 191 +++++++++ test/densities/factorizedjoint_tests.jl | 42 ++ test/prod_setuptests.jl | 100 +++++ test/prod_tests.jl | 170 ++++++++ test/promotion_setuptests.jl | 26 ++ test/promotion_tests.jl | 122 ++++++ test/runtests.jl | 16 +- test/statsfuns_tests.jl | 29 ++ 15 files changed, 1518 insertions(+), 10 deletions(-) create mode 100644 src/densities/factorizedjoint.jl create mode 100644 src/prod.jl create mode 100644 src/promotion.jl create mode 100644 src/statsfuns.jl create mode 100644 test/densities/factorizedjoint_tests.jl create mode 100644 test/prod_setuptests.jl create mode 100644 test/prod_tests.jl create mode 100644 test/promotion_setuptests.jl create mode 100644 test/promotion_tests.jl create mode 100644 test/statsfuns_tests.jl diff --git a/Project.toml b/Project.toml index 5d9f7d0..b14cedd 100644 --- a/Project.toml +++ b/Project.toml @@ -1,13 +1,32 @@ name = "BayesBase" uuid = "b4ee3484-f114-42fe-b91c-797d54a0c67e" -authors = ["Bagaev Dmitry and contributors"] -version = "1.0.0-DEV" +authors = ["Bagaev Dmitry and contributors"] +version = "1.0.0" + +[deps] +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0" +StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +TinyHugeNumbers = "783c9a47-75a3-44ac-a16b-f1ab7b3acf04" [compat] -julia = "1" +Distributions = "0.25" +Random = "1.9" +Statistics = "1.9" +StatsAPI = "1.7" +StatsBase = "0.34" +TinyHugeNumbers = "1.0" +julia = "1.9" [extras] +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +CpuId = "adafc99b-e345-5852-983c-f28acb93d879" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test"] +test = ["Aqua", "CpuId", "Test", "ReTestItems", "LinearAlgebra", "StableRNGs"] diff --git a/README.md b/README.md index e027e92..2fceebc 100644 --- a/README.md +++ b/README.md @@ -6,3 +6,10 @@ [![Coverage](https://codecov.io/gh/biaslab/BayesBase.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/biaslab/BayesBase.jl) [![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/invenia/BlueStyle) [![PkgEval](https://JuliaCI.github.io/NanosoldierReports/pkgeval_badges/B/BayesBase.svg)](https://JuliaCI.github.io/NanosoldierReports/pkgeval_badges/B/BayesBase.html) + +`BayesBase` is a package that serves as an umbrella, defining, exporting, and re-exporting methods essential for Bayesian statistics. +The `BayesBase` depends on [`Distributions`](https://github.com/JuliaStats/Distributions.jl), [`StatsBase`](https://github.com/JuliaStats/StatsBase.jl) and [`StatsAPI`](https://github.com/JuliaStats/StatsAPI.jl). + +Related projects: + +- [`ExponentialFamily`](https://github.com/biaslab/ExponentialFamily.jl) \ No newline at end of file diff --git a/docs/src/index.md b/docs/src/index.md index 61aee00..62713cd 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -9,6 +9,13 @@ Documentation for [BayesBase](https://github.com/biaslab/BayesBase.jl). ```@index ``` +# List of exported functions + +```@docs +StatsAPI.params +Statistics.mean +``` + ```@autodocs Modules = [BayesBase] ``` diff --git a/src/BayesBase.jl b/src/BayesBase.jl index 248c795..4177aaf 100644 --- a/src/BayesBase.jl +++ b/src/BayesBase.jl @@ -1,5 +1,72 @@ module BayesBase -# Write your package code here. +using TinyHugeNumbers + +using StatsAPI, StatsBase, Statistics, Distributions, Random + +using StatsAPI: params + +export params + +using Statistics: mean, median, std, var, cov + +export mean, median, std, var, cov + +using StatsBase: mode, entropy + +export mode, entropy + +using Distributions: + failprob, + succprob, + insupport, + shape, + scale, + rate, + invcov, + pdf, + logpdf, + logdetcov, + VariateForm, + ValueSupport, + Distribution, + Univariate, + Multivariate, + Matrixvariate, + variate_form, + value_support + +export failprob, + succprob, + insupport, + shape, + scale, + rate, + invcov, + pdf, + logpdf, + logdetcov, + VariateForm, + ValueSupport, + Distribution, + Univariate, + Multivariate, + Matrixvariate, + variate_form, + value_support + +using Base: precision, eltype, convert, length, isapprox + +export precision, eltype, convert, length, isapprox + +using Random: rand, rand! + +export rand, rand! + +include("statsfuns.jl") +include("promotion.jl") +include("prod.jl") + +include("densities/factorizedjoint.jl") end diff --git a/src/densities/factorizedjoint.jl b/src/densities/factorizedjoint.jl new file mode 100644 index 0000000..d5ab56a --- /dev/null +++ b/src/densities/factorizedjoint.jl @@ -0,0 +1,42 @@ +export FactorizedJoint + +""" + FactorizedJoint + +`FactorizedJoint` represents a joint distribution of independent random variables. +Use `component()` function or square-brackets indexing to access the marginal distribution for individual variables. +Use `components()` function to get a tuple of multipliers. +""" +struct FactorizedJoint{T} + multipliers::T +end + +BayesBase.components(joint::FactorizedJoint) = joint.multipliers + +Base.@propagate_inbounds function BayesBase.component(joint::FactorizedJoint, i::Int) + return getindex(joint, i) +end +Base.@propagate_inbounds function Base.getindex(joint::FactorizedJoint, i::Int) + return getindex(components(joint), i) +end + +BayesBase.length(joint::FactorizedJoint) = length(joint.multipliers) + +function BayesBase.isapprox(x::FactorizedJoint, y::FactorizedJoint; kwargs...) + return length(x) === length(y) && all( + tuple -> isapprox(tuple[1], tuple[2]; kwargs...), + zip(components(x), components(y)), + ) +end + +BayesBase.entropy(joint::FactorizedJoint) = mapreduce(entropy, +, components(joint)) + +function BayesBase.paramfloattype(joint::FactorizedJoint) + return BayesBase.paramfloattype(BayesBase.components(joint)) +end + +function BayesBase.convert_paramfloattype(::Type{T}, joint::FactorizedJoint) where {T} + return FactorizedJoint( + map(e -> BayesBase.convert_paramfloattype(T, joint), BayesBase.components(joint)) + ) +end \ No newline at end of file diff --git a/src/prod.jl b/src/prod.jl new file mode 100644 index 0000000..b17aee4 --- /dev/null +++ b/src/prod.jl @@ -0,0 +1,501 @@ + +import Distributions: VariateForm, ValueSupport, variate_form, value_support, support +import Base: prod, prod!, show, showerror + +export prod, + default_prod_rule, + fuse_supports, + ClosedProd, + PreserveTypeProd, + PreserveTypeLeftProd, + PreserveTypeRightProd, + GenericProd, + ProductOf, + LinearizedProductOf + +""" + UnspecifiedProd + +A strategy for the `prod` function, which does not compute the `prod`, but instead fails in run-time and prints a descriptive error message. + +See also: [`prod`](@ref), [`ClosedProd`](@ref), [`GenericProd`](@ref) +""" +struct UnspecifiedProd end + +""" + prod(strategy, left, right) + +`prod` function is used to find a product of two probability distributions (or any other objects) over same variable (e.g. 𝓝(x|μ_1, σ_1) × 𝓝(x|μ_2, σ_2)). +There are multiple strategies for prod function, e.g. `ClosedProd`, `GenericProd` or `PreserveTypeProd`. + +# Examples: + +```jldoctest +julia> product = prod(PreserveTypeProd(Distribution), NormalMeanVariance(-1.0, 1.0), NormalMeanVariance(1.0, 1.0)) +NormalWeightedMeanPrecision{Float64}(xi=0.0, w=2.0) + +julia> mean(product), var(product) +(0.0, 0.5) +``` + +```jldoctest +julia> product = prod(PreserveTypeProd(NormalMeanVariance), NormalMeanVariance(-1.0, 1.0), NormalMeanVariance(1.0, 1.0)) +NormalMeanVariance{Float64}(μ=0.0, v=0.5) + +julia> mean(product), var(product) +(0.0, 0.5) +``` + +```jldoctest +julia> product = prod(PreserveTypeProd(ExponentialFamilyDistribution), NormalMeanVariance(-1.0, 1.0), NormalMeanVariance(1.0, 1.0)) +ExponentialFamily(NormalMeanVariance) + +julia> mean(product), var(product) +(0.0, 0.5) +``` + +See also: [`default_prod_rule`](@ref), [`ClosedProd`](@ref), [`PreserveTypeProd`](@ref), [`GenericProd`](@ref) +""" +function Base.prod(strategy::UnspecifiedProd, left, right) + throw(MethodError(prod, (strategy, left, right))) +end + +Base.prod(::UnspecifiedProd, ::Missing, right) = right +Base.prod(::UnspecifiedProd, left, ::Missing) = left +Base.prod(::UnspecifiedProd, ::Missing, ::Missing) = missing + +""" + default_prod_rule(::Type, ::Type) + +Returns the most suitable `prod` rule for two given distribution types. +Returns `UnspecifiedProd` by default. + +See also: [`prod`](@ref), [`ClosedProd`](@ref), [`GenericProd`](@ref) +""" +default_prod_rule(::Type, ::Type) = UnspecifiedProd() + +function default_prod_rule(not_a_type, ::Type{R}) where {R} + return default_prod_rule(typeof(not_a_type), R) +end + +function default_prod_rule(::Type{L}, not_a_type) where {L} + return default_prod_rule(L, typeof(not_a_type)) +end + +function default_prod_rule(not_a_type_left, not_a_type_right) + return default_prod_rule(typeof(not_a_type_left), typeof(not_a_type_right)) +end + +""" + PreserveTypeProd{T} + +`PreserveTypeProd` is one of the strategies for `prod` function. This strategy constraint an output of a prod to be in some specific form. +By default it uses the strategy from `default_prod_rule` and converts the output to the prespecified type but can be overwritten +for some distributions for better performance. + +```jldoctest +julia> product = prod(PreserveTypeProd(NormalMeanVariance), NormalMeanVariance(-1.0, 1.0), NormalMeanVariance(1.0, 1.0)) +NormalMeanVariance{Float64}(μ=0.0, v=0.5) + +julia> mean(product), var(product) +(0.0, 0.5) +``` + +See also: [`prod`](@ref), [`ClosedProd`](@ref), [`PreserveTypeLeftProd`](@ref), [`PreserveTypeRightProd`](@ref), [`GenericProd`](@ref) +""" +struct PreserveTypeProd{T} end + +PreserveTypeProd(::Type{T}) where {T} = PreserveTypeProd{T}() + +function Base.prod(::PreserveTypeProd{T}, left, right) where {T} + return convert(T, prod(symmetric_default_prod_rule(left, right), left, right)) +end + +Base.prod(::PreserveTypeProd, ::Missing, right) = right +Base.prod(::PreserveTypeProd, left, ::Missing) = left +Base.prod(::PreserveTypeProd, ::Missing, ::Missing) = missing + +""" + PreserveTypeLeftProd + +An alias for the `PreserveTypeProd(L)` where `L` is the type of the `left` argument of the `prod` function. + +```jldoctest +julia> product = prod(PreserveTypeLeftProd(), NormalMeanVariance(-1.0, 1.0), NormalMeanPrecision(1.0, 1.0)) +NormalMeanVariance{Float64}(μ=0.0, v=0.5) + +julia> mean(product), var(product) +(0.0, 0.5) +``` + +See also: [`prod`](@ref), [`PreserveTypeProd`](@ref), [`PreserveTypeRightProd`](@ref), [`GenericProd`](@ref) +""" +struct PreserveTypeLeftProd end + +function Base.prod(::PreserveTypeLeftProd, left::L, right) where {L} + return prod(PreserveTypeProd(L), left, right) +end + +""" + PreserveTypeRightProd + +An alias for the `PreserveTypeProd(R)` where `R` is the type of the `right` argument of the `prod` function. + +```jldoctest +julia> product = prod(PreserveTypeRightProd(), NormalMeanVariance(-1.0, 1.0), NormalMeanPrecision(1.0, 1.0)) +NormalMeanPrecision{Float64}(μ=0.0, w=2.0) + +julia> mean(product), var(product) +(0.0, 0.5) +``` + +See also: [`prod`](@ref), [`PreserveTypeProd`](@ref), [`PreserveTypeLeftProd`](@ref), [`GenericProd`](@ref) +""" +struct PreserveTypeRightProd end + +function Base.prod(::PreserveTypeRightProd, left, right::R) where {R} + return prod(PreserveTypeProd(R), left, right) +end + +""" + ClosedProd + +`ClosedProd` is one of the strategies for `prod` function. For example, if both inputs are of type `Distribution`, then `ClosedProd` would fallback to `PreserveTypeProd(Distribution)`. + +See also: [`prod`](@ref), [`PreserveTypeProd`](@ref), [`GenericProd`](@ref) +""" +struct ClosedProd end + +Base.prod(::ClosedProd, ::Missing, right) = right +Base.prod(::ClosedProd, left, ::Missing) = left +Base.prod(::ClosedProd, ::Missing, ::Missing) = missing + +# We assume that we want to preserve the `Distribution` when working with two `Distribution`s +Base.prod(::ClosedProd, left::Distribution, right::Distribution) = prod(PreserveTypeProd(Distribution), left, right) + +# This is a hidden prod strategy to ensure symmetricity in the `default_prod_rule`. +# Most of the automatic prod rule resolution relies on the `symmetric_default_prod_rule` instead of just `default_prod_rule` +# The `symmetric_default_prod_rule` will adjust the prod rule in case if there is an available prod rule with swapped arguments +struct SwapArgumentsProd{S} + strategy::S +end + +Base.prod(swap::SwapArgumentsProd, left, right) = prod(swap.strategy, right, left) + +function symmetric_default_prod_rule(left, right) + return symmetric_default_prod_rule( + default_prod_rule(left, right), default_prod_rule(right, left), left, right + ) +end + +symmetric_default_prod_rule(strategy1, strategy2, left, right) = strategy1 +symmetric_default_prod_rule(strategy1, ::UnspecifiedProd, left, right) = strategy1 +function symmetric_default_prod_rule(::UnspecifiedProd, strategy2, left, right) + return SwapArgumentsProd(strategy2) +end +function symmetric_default_prod_rule(::UnspecifiedProd, ::UnspecifiedProd, left, right) + return UnspecifiedProd() +end + +""" + fuse_supports(left, right) + +Fuse supports of two distributions of `left` and `right`. +By default, checks that the supports are identical and throws an error otherwise. +Can implement specific fusions for specific distributions. + +See also: [`prod`](@ref), [`ProductOf`](@ref) +""" +function fuse_supports(left, right) + if !isequal(support(left), support(right)) + error("Cannot form a `ProductOf` $(left) & `$(right)`. Support is incompatible.") + end + return support(left) +end + +""" + ProductOf + +A generic structure representing a product of two distributions. +Can be viewed as a tuple of `(left, right)`. +Does not check nor supports neither variate forms during the creation stage. +Uses the `fuse_support` function to fuse supports of two different distributions. + +This object does not define any statistical properties (such as `mean` or `var` etc) and cannot be used as a distribution explicitly. +Instead, it must be further approximated as a member of some other distribution. + +See also: [`prod`](@ref), [`GenericProd`](@ref), [`ExponentialFamily.fuse_supports`](@ref) +""" +struct ProductOf{L,R} + left::L + right::R +end + +getleft(product::ProductOf) = product.left +getright(product::ProductOf) = product.right + +function Base.:(==)(left::ProductOf, right::ProductOf) + return (getleft(left) == getleft(right)) && (getright(left) == getright(right)) +end + +function Base.show(io::IO, product::ProductOf) + return print(io, "ProductOf(", getleft(product), ",", getright(product), ")") +end + +function Distributions.support(product::ProductOf) + return fuse_supports(getleft(product), getright(product)) +end + +Distributions.pdf(product::ProductOf, x) = exp(logpdf(product, x)) + +function Distributions.logpdf(product::ProductOf, x) + return Distributions.logpdf(getleft(product), x) + + Distributions.logpdf(getright(product), x) +end + +Distributions.variate_form(::P) where {P<:ProductOf} = variate_form(P) + +function Distributions.variate_form(::Type{ProductOf{L,R}}) where {L,R} + return _check_product_variate_form(variate_form(L), variate_form(R)) +end + +_check_product_variate_form(::Type{F}, ::Type{F}) where {F<:VariateForm} = F + +function _check_product_variate_form( + ::Type{F1}, ::Type{F2} +) where {F1<:VariateForm,F2<:VariateForm} + return error( + "`ProductOf` has different variate forms for left ($F1) and right ($F2) entries." + ) +end + +Distributions.value_support(::P) where {P<:ProductOf} = value_support(P) + +function Distributions.value_support(::Type{ProductOf{L,R}}) where {L,R} + return _check_product_value_support(value_support(L), value_support(R)) +end + +_check_product_value_support(::Type{S}, ::Type{S}) where {S<:ValueSupport} = S + +function _check_product_value_support( + ::Type{S1}, ::Type{S2} +) where {S1<:ValueSupport,S2<:ValueSupport} + return error( + "`ProductOf` has different value supports for left ($S1) and right ($S2) entries." + ) +end + +""" + GenericProd + +`GenericProd` is one of the strategies for `prod` function. This strategy does always produces a result, +even if the closed form product is not availble, in which case simply returns the `ProductOf` object. `GenericProd` sometimes +fallbacks to the `default_prod_rule` which it may or may not use under some circumstances. +For example if the `default_prod_rule` is `ClosedProd` - `GenericProd` will try to optimize the tree with +analytical closed solutions (if possible). + +See also: [`prod`](@ref), [`ProductOf`](@ref), [`ClosedProd`](@ref), [`PreserveTypeProd`](@ref), [`default_prod_rule`](@ref) +""" +struct GenericProd end + +Base.show(io::IO, ::GenericProd) = print(io, "GenericProd()") + +Base.prod(::GenericProd, ::Missing, right) = right +Base.prod(::GenericProd, left, ::Missing) = left +Base.prod(::GenericProd, ::Missing, ::Missing) = missing + +function Base.prod(::GenericProd, left::L, right::R) where {L,R} + return prod(GenericProd(), symmetric_default_prod_rule(L, R), left, right) +end + +Base.prod(::GenericProd, specified_prod, left, right) = prod(specified_prod, left, right) +Base.prod(::GenericProd, ::UnspecifiedProd, left, right) = ProductOf(left, right) + +# Try to fuse the tree with analytical solutions (if possible) +# Case (L × R) × T +function Base.prod(::GenericProd, left::ProductOf{L,R}, right::T) where {L,R,T} + return prod( + GenericProd(), + symmetric_default_prod_rule(L, T), + symmetric_default_prod_rule(R, T), + left, + right, + ) +end + +# (L × R) × T cannot be fused, simply return the `ProductOf` +function Base.prod( + ::GenericProd, ::UnspecifiedProd, ::UnspecifiedProd, left::ProductOf, right +) + return ProductOf(left, right) +end + +# (L × R) × T can be fused efficiently as (L × T) × R, because L × T has defined the `something` default prod +function Base.prod(::GenericProd, something, ::UnspecifiedProd, left::ProductOf, right) + return ProductOf(prod(something, getleft(left), right), getright(left)) +end + +# (L × R) × T can be fused efficiently as L × (R × T), because R × T has defined the `something` default prod +function Base.prod(::GenericProd, ::UnspecifiedProd, something, left::ProductOf, right) + return ProductOf(getleft(left), prod(something, getright(left), right)) +end + +# (L × R) × T can be fused efficiently as L × (R × T), because both L × T and R × T has defined the `something` default prod, but we choose R × T +function Base.prod(::GenericProd, _, something, left::ProductOf, right) + return ProductOf(getleft(left), prod(something, getright(left), right)) +end + +# Case T × (L × R) +function Base.prod(::GenericProd, left::T, right::ProductOf{L,R}) where {L,R,T} + return prod( + GenericProd(), + symmetric_default_prod_rule(T, L), + symmetric_default_prod_rule(T, R), + left, + right, + ) +end + +# T × (L × R) cannot be fused, simply return the `ProductOf` +function Base.prod( + ::GenericProd, ::UnspecifiedProd, ::UnspecifiedProd, left, right::ProductOf +) + return ProductOf(left, right) +end + +# T × (L × R) can be fused efficiently as (T × L) × R, because T × L has defined the `something` default prod +function Base.prod(::GenericProd, something, ::UnspecifiedProd, left, right::ProductOf) + return ProductOf(prod(something, left, getleft(right)), getright(right)) +end + +# T × (L × R) can be fused efficiently as L × (T × R), because T × R has defined the `something` default prod +function Base.prod(::GenericProd, ::UnspecifiedProd, something, left, right::ProductOf) + return ProductOf(getleft(right), prod(something, left, getright(right))) +end + +# T × (L × R) can be fused efficiently as L × (T × R), because both T × L and T × R has defined the `something` default prod, but we choose T × L +function Base.prod(::GenericProd, something, _, left, right::ProductOf) + return ProductOf(prod(something, left, getleft(right)), getright(right)) +end + +""" + LinearizedProductOf + +An efficient __linearized__ implementation of product of multiple distributions. +This structure prevents `ProductOf` tree from growing too much in case of identical objects. +This trick significantly reduces Julia compilation times when closed product rules are not available but distributions are of the same type. +Essentially this structure linearizes leaves of the `ProductOf` tree in case if it sees objects of the same type (via dispatch). + +See also: [`ProductOf`](@ref), [`GenericProd`] +""" +struct LinearizedProductOf{F} + vector::Vector{F} + length::Int # `length` here is needed for extra safety as we implicitly mutate `vector` in `prod` +end + +function Base.push!(product::LinearizedProductOf{F}, item::F) where {F} + vector = product.vector + vlength = length(vector) + return LinearizedProductOf(push!(vector, item), vlength + 1) +end + +Distributions.support(dist::LinearizedProductOf) = support(first(dist.vector)) + +Base.length(product::LinearizedProductOf) = product.length +Base.eltype(product::LinearizedProductOf) = eltype(first(product.vector)) + +function Base.:(==)(left::LinearizedProductOf, right::LinearizedProductOf) + return (left.length == right.length) && (left.vector == right.vector) +end + +function BayesBase.samplefloattype(product::LinearizedProductOf) + return samplefloattype(first(product.vector)) +end + +Distributions.variate_form(::Type{<:LinearizedProductOf{F}}) where {F} = variate_form(F) +Distributions.variate_form(::LinearizedProductOf{F}) where {F} = variate_form(F) + +Distributions.value_support(::Type{<:LinearizedProductOf{F}}) where {F} = value_support(F) +Distributions.value_support(::LinearizedProductOf{F}) where {F} = value_support(F) + +function Base.show(io::IO, product::LinearizedProductOf{F}) where {F} + return print(io, "LinearizedProductOf(", F, ", length = ", product.length, ")") +end + +function Distributions.logpdf(dist::LinearizedProductOf, x) + return mapreduce( + (d) -> logpdf(d, x), +, view(dist.vector, 1:min(dist.length, length(dist.vector))) + ) +end + +Distributions.pdf(dist::LinearizedProductOf, x) = exp(logpdf(dist, x)) + +# We assume that it is better (really) to preserve the type of the `LinearizedProductOf`, it is just faster for the compiler +function BayesBase.default_prod_rule(::Type{F}, ::Type{LinearizedProductOf{F}}) where {F} + return PreserveTypeProd(LinearizedProductOf{F}) +end +function BayesBase.default_prod_rule(::Type{LinearizedProductOf{F}}, ::Type{F}) where {F} + return PreserveTypeProd(LinearizedProductOf{F}) +end + +function Base.prod( + ::PreserveTypeProd{LinearizedProductOf{F}}, product::LinearizedProductOf{F}, item::F +) where {F} + return push!(product, item) +end + +function Base.prod( + ::PreserveTypeProd{LinearizedProductOf{F}}, item::F, product::LinearizedProductOf{F} +) where {F} + return push!(product, item) +end + +function Base.prod( + ::GenericProd, ::UnspecifiedProd, ::UnspecifiedProd, left::ProductOf{F,F}, right::F +) where {F} + return LinearizedProductOf(F[getleft(left), getright(left), right], 3) +end + +function Base.prod( + ::GenericProd, ::UnspecifiedProd, ::UnspecifiedProd, left::ProductOf{L,R}, right::R +) where {L,R} + return ProductOf(getleft(left), LinearizedProductOf(R[getright(left), right], 2)) +end + +function Base.prod( + ::GenericProd, ::UnspecifiedProd, ::UnspecifiedProd, left::ProductOf{L,R}, right::L +) where {L,R} + return ProductOf(LinearizedProductOf(L[getleft(left), right], 2), getright(left)) +end + +function Base.prod( + ::GenericProd, ::UnspecifiedProd, ::UnspecifiedProd, left::L, right::ProductOf{L,R} +) where {L,R} + return ProductOf(LinearizedProductOf(L[left, getleft(right)], 2), getright(right)) +end + +function Base.prod( + ::GenericProd, ::UnspecifiedProd, ::UnspecifiedProd, left::R, right::ProductOf{L,R} +) where {L,R} + return ProductOf(getleft(right), LinearizedProductOf(R[left, getright(right)], 2)) +end + +function Base.prod( + ::GenericProd, + ::UnspecifiedProd, + ::UnspecifiedProd, + left::ProductOf{L,LinearizedProductOf{R}}, + right::R, +) where {L,R} + return ProductOf(getleft(left), push!(getright(left), right)) +end + +function Base.prod( + ::GenericProd, + ::UnspecifiedProd, + ::UnspecifiedProd, + left::ProductOf{LinearizedProductOf{L},R}, + right::L, +) where {L,R} + return ProductOf(push!(getleft(left), right), getright(left)) +end diff --git a/src/promotion.jl b/src/promotion.jl new file mode 100644 index 0000000..d49e510 --- /dev/null +++ b/src/promotion.jl @@ -0,0 +1,179 @@ +export deep_eltype, + promote_variate_type, + paramfloattype, + promote_paramfloattype, + convert_paramfloattype, + sampletype, + samplefloattype, + promote_sampletype, + promote_samplefloattype + +# Julia does not really like expressions of the form +# map((e) -> convert(T, e), collection) +# because type `T` is inside lambda function +# https://github.com/JuliaLang/julia/issues/15276 +# https://github.com/JuliaLang/julia/issues/47760 +struct PromoteTypeConverter{T,C} + convert::C +end + +PromoteTypeConverter(::Type{T}, convert::C) where {T,C} = PromoteTypeConverter{T,C}(convert) + +(converter::PromoteTypeConverter{T})(something) where {T} = converter.convert(T, something) + +""" + deep_eltype(T) + +Returns: +- `deep_eltype` of `T` if `T` is an `AbstractArray` container +- `T` otherwise + +```jldoctest +julia> deep_eltype(Float64) +Float64 + +julia> deep_eltype(Vector{Float64}) +Float64 + +julia> deep_eltype(Vector{Matrix{Vector{Float64}}}) +Float64 +``` +""" +function deep_eltype end + +deep_eltype(::Type{T}) where {T} = T +deep_eltype(::Type{T}) where {T<:AbstractArray} = deep_eltype(eltype(T)) +deep_eltype(any) = deep_eltype(typeof(any)) + +""" + promote_variate_PromoteTypeConverter(::Type{ <: VariateForm }, distribution_type) + +Promotes (if possible) a `distribution_type` to be of the specified variate form. +""" +function promote_variate_type end + +""" + promote_variate_type(::Type{D}, distribution_type) where { D <: Distribution } + +Promotes (if possible) a `distribution_type` to be of the same variate form as `D`. +""" +function promote_variate_type(::Type{D}, T) where {D<:Distribution} + return promote_variate_type(variate_form(D), T) +end + +""" + paramfloattype(distribution) + +Returns the underlying float type of distribution's parameters. + +See also: [`ExponentialFamily.promote_paramfloattype`](@ref), [`ExponentialFamily.convert_paramfloattype`](@ref) +""" +function paramfloattype(distribution::Distribution) + return promote_type(map(deep_eltype, params(distribution))...) +end +paramfloattype(nt::NamedTuple) = promote_paramfloattype(values(nt)) +paramfloattype(t::Tuple) = promote_paramfloattype(t...) + +# `Bool` is the smallest possible type, should not play any role in the promotion +paramfloattype(::Nothing) = Bool + +""" + promote_paramfloattype(distributions...) + +Promotes `paramfloattype` of the `distributions` to a single type. See also `promote_type`. + +See also: [`ExponentialFamily.paramfloattype`](@ref), [`ExponentialFamily.convert_paramfloattype`](@ref) +""" +function promote_paramfloattype(distributions...) + return promote_type(map(paramfloattype, distributions)...) +end + +""" + convert_paramfloattype(::Type{T}, distribution) + +Converts (if possible) the params float type of the `distribution` to be of type `T`. + +See also: [`ExponentialFamily.paramfloattype`](@ref), [`ExponentialFamily.promote_paramfloattype`](@ref) +""" +function convert_paramfloattype(::Type{T}, distribution::Distribution) where {T} + return automatic_convert_paramfloattype( + distribution_typewrapper(distribution), + map(convert_paramfloattype(T), params(distribution)), + ) +end +function convert_paramfloattype(::Type{T}, collection::NamedTuple) where {T} + return map(convert_paramfloattype(T), collection) +end +function convert_paramfloattype(collection::NamedTuple) + return convert_paramfloattype(paramfloattype(collection), collection) +end +function convert_paramfloattype(::Type{T}) where {T} + return PromoteTypeConverter(T, convert_paramfloattype) +end + +# We attempt to automatically construct a new distribution with a desired paramfloattype +# This function assumes that the constructor `D(...)` accepts the same order of parameters as +# returned from the `params` function. It is the case for distributions from `Distributions.jl` +automatic_convert_paramfloattype(::Type{D}, params) where {D<:Distribution} = D(params...) +function automatic_convert_paramfloattype(::Type{D}, params) where {D} + return error( + "Cannot automatically construct a distribution of type `$D` with params = $(params)" + ) +end + +""" + convert_paramfloattype(::Type{T}, container) + +Converts (if possible) the elements of the `container` to be of type `T`. +""" +function convert_paramfloattype(::Type{T}, container::AbstractArray) where {T} + return convert(AbstractArray{T}, container) +end +convert_paramfloattype(::Type{T}, number::Number) where {T} = convert(T, number) +convert_paramfloattype(::Type, ::Nothing) = nothing + +""" + sampletype(distribution) + +Returns a type of the distribution. By default fallbacks to the `eltype`. + +See also: [`ExponentialFamily.samplefloattype`](@ref), [`ExponentialFamily.promote_sampletype`](@ref), [`ExponentialFamily.promote_samplefloattype`](@ref) +""" +sampletype(distribution) = eltype(distribution) + +function sampletype(distribution::Distribution) + return sampletype(variate_form(typeof(distribution)), distribution) +end +sampletype(::Type{Univariate}, distribution) = eltype(distribution) +sampletype(::Type{Multivariate}, distribution) = Vector{eltype(distribution)} +sampletype(::Type{Matrixvariate}, distribution) = Matrix{eltype(distribution)} + +""" + samplefloattype(distribution) + +Returns a type of the distribution or the underlying float type in case if sample is `Multivariate` or `Matrixvariate`. +By default fallbacks to the `deep_eltype(sampletype(distribution))`. + +See also: [`ExponentialFamily.sampletype`](@ref), [`ExponentialFamily.promote_sampletype`](@ref), [`ExponentialFamily.promote_samplefloattype`](@ref) +""" +samplefloattype(distribution) = deep_eltype(sampletype(distribution)) + +""" + promote_sampletype(distributions...) + +Promotes `sampletype` of the `distributions` to a single type. See also `promote_type`. + +See also: [`ExponentialFamily.sampletype`](@ref), [`ExponentialFamily.samplefloattype`](@ref), [`ExponentialFamily.promote_samplefloattype`](@ref) +""" +promote_sampletype(distributions...) = promote_type(map(sampletype, distributions)...) + +""" + promote_samplefloattype(distributions...) + +Promotes `samplefloattype` of the `distributions` to a single type. See also `promote_type`. + +See also: [`ExponentialFamily.sampletype`](@ref), [`ExponentialFamily.samplefloattype`](@ref), [`ExponentialFamily.promote_sampletype`](@ref) +""" +function promote_samplefloattype(distributions...) + return promote_type(map(samplefloattype, distributions)...) +end \ No newline at end of file diff --git a/src/statsfuns.jl b/src/statsfuns.jl new file mode 100644 index 0000000..8e7daab --- /dev/null +++ b/src/statsfuns.jl @@ -0,0 +1,191 @@ +export mirrorlog, + xtlog, + logmvbeta, + clamplog, + mvtrigamma, + vague, + isproper, + probvec, + weightedmean, + mean_cov, + mean_var, + mean_std, + mean_invcov, + weightedmean_cov, + weightedmean_var, + weightedmean_std, + weightedmean_invcov, + weightedmean_precision, + logpdf_sampling_optimized, + logpdf_optimized, + sampling_optimized, + components, + component, + distribution_typewrapper + +""" + mirrorlog(x) + +Returns `log(1 - x)`. +""" +mirrorlog(x) = log(one(x) - x) + +""" + xtlog(x) + +Returns `x * log(x)`. +""" +xtlog(x) = x * log(x) + +""" + clamplog(x) + +Same as `log` but clamps the input argument `x` to be in the range `tiny <= x <= typemax(x)` such that `log(0)` does not explode. +""" +clamplog(x) = log(clamp(x, tiny, typemax(x))) + +""" + logmvbeta(x) + +Uses the numerically stable algorithm to compute the logarithm of the multivariate beta distribution over with the parameter vector x. +""" +logmvbeta(x) = sum(loggamma, x) - loggamma(sum(x)) + +""" + mvtrigamma(p, x) + +Computes multivariate trigamma function . +""" +mvtrigamma(p, x) = sum(trigamma(x + (one(x) - i) / 2) for i in 1:p) + +""" + vague(distribution_type, [ dims... ]) + +Returns uninformative probability distribution of the given type. +""" +function vague end + +""" + isproper(T, args...) + +Checks if `args...` are compatible with distribution of type `T`. +""" +function isproper end + +function compute_logscale end + +""" + probvec(d) + +Returns the probability vector of the given distribution. +""" +function probvec end + +""" + weightedmean(d) + +Returns the weighted mean of the given distribution. +Alias to `invcov(d) * mean(d)`, but can be specialized +""" +weightedmean(d) = invcov(d) * mean(d) + +""" +Alias for `(mean(d), cov(d))`, but can be specialized. +""" +mean_cov(something) = (mean(something), cov(something)) + +""" +Alias for `(mean(d), var(d))`, but can be specialized. +""" +mean_var(something) = (mean(something), var(something)) + +""" +Alias for `(mean(d), std(d))`, but can be specialized. +""" +mean_std(something) = (mean(something), std(something)) + +""" +Alias for `(mean(d), invcov(d))`, but can be specialized. +""" +mean_invcov(something) = (mean(something), invcov(something)) + +""" +Alias for `mean_invcov(d)`, but can be specialized. +""" +mean_precision(something) = mean_invcov(something) + +""" +Alias for `(weightedmean(d), cov(d))`, but can be specialized. +""" +weightedmean_cov(something) = (weightedmean(something), cov(something)) + +""" +Alias for `(weightedmean(d), var(d))`, but can be specialized. +""" +weightedmean_var(something) = (weightedmean(something), var(something)) + +""" +Alias for `(weightedmean(d), std(d))`, but can be specialized. +""" +weightedmean_std(something) = (weightedmean(something), std(something)) + +""" +Alias for `(weightedmean(d), invcov(d))`, but can be specialized. +""" + +""" +Alias for `weightedmean_invcov(d)`, but can be specialized. +""" +weightedmean_invcov(something) = (weightedmean(something), invcov(something)) +weightedmean_precision(something) = weightedmean_invcov(something) + +""" + logpdf_sampling_optimized(d) + +`logpdf_sample_optimized` function takes as an input a distribution `d` and returns corresponding optimized two versions +for taking `logpdf()` and sampling with `rand!` respectively. Alias for `(logpdf_optimized(d), sampling_optimized(d))`, but can be specialized. +""" +function logpdf_sampling_optimized(something) + return (logpdf_optimized(something), sampling_optimized(something)) +end + +""" + logpdf_optimized(d) + +Returns a version of `d` specifically optimized to call `logpdf(d, x)`. By default returns the same `d`, but can be specialized. +""" +logpdf_optimized(something) = something + +""" + sampling_optimized(d) + +Returns a version of `d` specifically optimized to call `rand` and `rand!`. By default returns the same `d`, but can be specialized. +""" +sampling_optimized(something) = something + +""" + components(d) + +Returns components of a distribution `d` (joint or a mixture). +""" +function components end + +""" + component(d, k) + +Returns `k`-th component of a distribution `d` (joint or a mixture). +""" +function component end + +""" +Strips type parameters from the type of the `distribution`. +""" +distribution_typewrapper(distribution) = generated_distribution_typewrapper(distribution) + +# Returns a wrapper distribution for a `<:Distribution` type, this function uses internals of Julia +# It is not ideal, but is fine for now, if Julia changes it internals such that does not work +# We will need to write the `distribution_typewrapper` method for each support member of exponential family +# e.g. `distribution_typewrapper(::Bernoulli) = Bernoulli` +@generated function generated_distribution_typewrapper(distribution) + return Base.typename(distribution).wrapper +end \ No newline at end of file diff --git a/test/densities/factorizedjoint_tests.jl b/test/densities/factorizedjoint_tests.jl new file mode 100644 index 0000000..736d45c --- /dev/null +++ b/test/densities/factorizedjoint_tests.jl @@ -0,0 +1,42 @@ +@testitem "FactorizedJoint" begin + using Distributions + + vmultipliers = [ + (Normal(),), + (Normal(), Beta(1.0, 1.0)), + (Normal(), Gamma(), MvNormal([0.0, 0.0], [1.0 0.0; 0.0 1.0])), + ] + + @testset "getindex" begin + for multipliers in vmultipliers + product = FactorizedJoint(multipliers) + @test length(product) === length(multipliers) + for i in eachindex(multipliers) + @test product[i] === multipliers[i] + end + end + end + + @testset "entropy" begin + for multipliers in vmultipliers + product = FactorizedJoint(multipliers) + @test entropy(product) ≈ mapreduce(entropy, +, multipliers) + end + end + + @testset "isapprox" begin + @test FactorizedJoint((Normal(),)) ≈ FactorizedJoint((Normal(),)) + @test !(FactorizedJoint((Normal(0, 1),)) ≈ FactorizedJoint((Normal(1, 1),))) + + @test FactorizedJoint((Gamma(1.0, 1.0), Normal(0.0, 1.0))) ≈ + FactorizedJoint((Gamma(1.000001, 1.0), Normal(0.0, 1.0000000001))) atol = 1e-5 + @test !( + FactorizedJoint((Gamma(1.0, 1.0), Normal(0.0, 1.0))) ≈ + FactorizedJoint((Gamma(1.000001, 1.0), Normal(0.0, 5.0000000001))) + ) + @test !( + FactorizedJoint((Gamma(1.0, 2.0), Normal(0.0, 1.0))) ≈ + FactorizedJoint((Gamma(1.000001, 1.0), Normal(0.0, 1.0000000001))) + ) + end +end \ No newline at end of file diff --git a/test/prod_setuptests.jl b/test/prod_setuptests.jl new file mode 100644 index 0000000..7cb8128 --- /dev/null +++ b/test/prod_setuptests.jl @@ -0,0 +1,100 @@ +using BayesBase, Distributions + +import BayesBase: + prod, + default_prod_rule, + ProductOf, + LinearizedProductOf, + getleft, + getright, + UnspecifiedProd, + PreserveTypeProd, + PreserveTypeLeftProd, + PreserveTypeRightProd, + ClosedProd, + GenericProd + +## =========================================================================== +## Tests fixtures + +# An object, which does not specify any prod rules +struct SomeUnknownObject end + +# Two objects that +# - implement `ClosedProd` between each other +# - implement `prod` with `ClosedProd` between each other +# - can be eaily converted between each other +# - can be converted to an `Int` +struct ObjectWithClosedProd1 end +struct ObjectWithClosedProd2 end + +function BayesBase.default_prod_rule( + ::Type{ObjectWithClosedProd1}, ::Type{ObjectWithClosedProd1} +) + return PreserveTypeProd(ObjectWithClosedProd1) +end +function BayesBase.default_prod_rule( + ::Type{ObjectWithClosedProd2}, ::Type{ObjectWithClosedProd2} +) + return PreserveTypeProd(ObjectWithClosedProd2) +end +function BayesBase.default_prod_rule( + ::Type{ObjectWithClosedProd1}, ::Type{ObjectWithClosedProd2} +) + return PreserveTypeProd(ObjectWithClosedProd1) +end +function BayesBase.default_prod_rule( + ::Type{ObjectWithClosedProd2}, ::Type{ObjectWithClosedProd1} +) + return PreserveTypeProd(ObjectWithClosedProd2) +end + +function BayesBase.prod( + ::PreserveTypeProd{ObjectWithClosedProd1}, + ::ObjectWithClosedProd1, + ::ObjectWithClosedProd1, +) + return ObjectWithClosedProd1() +end + +function BayesBase.prod( + ::PreserveTypeProd{ObjectWithClosedProd2}, + ::ObjectWithClosedProd2, + ::ObjectWithClosedProd2, +) + return ObjectWithClosedProd2() +end + +function BayesBase.prod( + ::PreserveTypeProd{ObjectWithClosedProd1}, + ::ObjectWithClosedProd1, + ::ObjectWithClosedProd2, +) + return ObjectWithClosedProd1() +end + +function BayesBase.prod( + ::PreserveTypeProd{ObjectWithClosedProd2}, + ::ObjectWithClosedProd2, + ::ObjectWithClosedProd1, +) + return ObjectWithClosedProd2() +end + +function Base.convert(::Type{ObjectWithClosedProd1}, ::ObjectWithClosedProd2) + return ObjectWithClosedProd1() +end +function Base.convert(::Type{ObjectWithClosedProd2}, ::ObjectWithClosedProd1) + return ObjectWithClosedProd2() +end + +Base.convert(::Type{Int}, ::ObjectWithClosedProd1) = 1 +Base.convert(::Type{Int}, ::ObjectWithClosedProd2) = 2 + +struct ADistributionObject <: ContinuousUnivariateDistribution end + +function BayesBase.prod( + ::PreserveTypeProd{Distribution}, ::ADistributionObject, ::ADistributionObject +) + return ADistributionObject() +end diff --git a/test/prod_tests.jl b/test/prod_tests.jl new file mode 100644 index 0000000..c2b1a2f --- /dev/null +++ b/test/prod_tests.jl @@ -0,0 +1,170 @@ + +@testitem "UnspecifiedProd" begin + include("./prod_setuptests.jl") + + @testset "`default_prod_rule` should return `UnspecifiedProd` for two unknown objects" begin + @test default_prod_rule(SomeUnknownObject, SomeUnknownObject) === UnspecifiedProd() + end + + @testset "`missing` should be ignored with the `UnspecifiedProd`" begin + @test prod(UnspecifiedProd(), missing, SomeUnknownObject()) === SomeUnknownObject() + @test prod(UnspecifiedProd(), SomeUnknownObject(), missing) === SomeUnknownObject() + @test prod(UnspecifiedProd(), missing, missing) === missing + end +end + +@testitem "ClosedProd" begin + include("./prod_setuptests.jl") + + @testset "`missing` should be ignored with the `ClosedProd`" begin + struct SomeObject end + @test prod(ClosedProd(), missing, SomeUnknownObject()) === SomeUnknownObject() + @test prod(ClosedProd(), SomeUnknownObject(), missing) === SomeUnknownObject() + @test prod(ClosedProd(), missing, missing) === missing + end + + @testset "`ClosedProd` for distribution objects should assume `ProdPreserveType(Distribution)`" begin + @test prod(ClosedProd(), ADistributionObject(), ADistributionObject()) isa ADistributionObject + end + +end + +@testitem "PreserveTypeProd" begin + include("./prod_setuptests.jl") + + @testset "`missing` should be ignored with the `PreserveTypeProd`" begin + # Can convert the result of the prod to the desired type + @test prod(PreserveTypeProd(SomeUnknownObject), missing, SomeUnknownObject()) isa SomeUnknownObject + @test prod(PreserveTypeProd(SomeUnknownObject), SomeUnknownObject(), missing) isa SomeUnknownObject + @test prod(PreserveTypeProd(Missing), missing, missing) isa Missing + @test prod(PreserveTypeProd(SomeUnknownObject), missing, missing) isa Missing + end + + @testset "`PreserveTypeLeftProd` should preserve the type of the left argument" begin + @test prod(PreserveTypeLeftProd(), ObjectWithClosedProd1(), ObjectWithClosedProd2()) isa ObjectWithClosedProd1 + @test prod(PreserveTypeLeftProd(), ObjectWithClosedProd2(), ObjectWithClosedProd1()) isa ObjectWithClosedProd2 + end + + @testset "`PreserveTypeRightProd` should preserve the type of the left argument" begin + @test prod(PreserveTypeRightProd(), ObjectWithClosedProd1(), ObjectWithClosedProd2()) isa ObjectWithClosedProd2 + @test prod(PreserveTypeRightProd(), ObjectWithClosedProd2(), ObjectWithClosedProd1()) isa ObjectWithClosedProd1 + end + + @testset "`ProdPreserveType(T)` should preserve the desired type of `T`" begin + @test prod(PreserveTypeProd(ObjectWithClosedProd1), ObjectWithClosedProd1(), ObjectWithClosedProd1()) isa + ObjectWithClosedProd1 + @test prod(PreserveTypeProd(ObjectWithClosedProd1), ObjectWithClosedProd1(), ObjectWithClosedProd2()) isa + ObjectWithClosedProd1 + @test prod(PreserveTypeProd(ObjectWithClosedProd1), ObjectWithClosedProd2(), ObjectWithClosedProd1()) isa + ObjectWithClosedProd1 + @test prod(PreserveTypeProd(ObjectWithClosedProd1), ObjectWithClosedProd2(), ObjectWithClosedProd2()) isa + ObjectWithClosedProd1 + + @test prod(PreserveTypeProd(ObjectWithClosedProd2), ObjectWithClosedProd1(), ObjectWithClosedProd1()) isa + ObjectWithClosedProd2 + @test prod(PreserveTypeProd(ObjectWithClosedProd2), ObjectWithClosedProd1(), ObjectWithClosedProd2()) isa + ObjectWithClosedProd2 + @test prod(PreserveTypeProd(ObjectWithClosedProd2), ObjectWithClosedProd2(), ObjectWithClosedProd1()) isa + ObjectWithClosedProd2 + @test prod(PreserveTypeProd(ObjectWithClosedProd2), ObjectWithClosedProd2(), ObjectWithClosedProd2()) isa + ObjectWithClosedProd2 + + # The output can be converted to an `Int` (see the fixtures above) + @test prod(PreserveTypeProd(Int), ObjectWithClosedProd1(), ObjectWithClosedProd1()) isa Int + @test prod(PreserveTypeProd(Int), ObjectWithClosedProd1(), ObjectWithClosedProd2()) isa Int + @test prod(PreserveTypeProd(Int), ObjectWithClosedProd2(), ObjectWithClosedProd1()) isa Int + @test prod(PreserveTypeProd(Int), ObjectWithClosedProd2(), ObjectWithClosedProd2()) isa Int + + # The output can not be converted to a `Float` (see the fixtures above) + @test_throws MethodError prod(PreserveTypeProd(Float64), ObjectWithClosedProd1(), ObjectWithClosedProd1()) + @test_throws MethodError prod(PreserveTypeProd(Float64), ObjectWithClosedProd1(), ObjectWithClosedProd2()) + @test_throws MethodError prod(PreserveTypeProd(Float64), ObjectWithClosedProd2(), ObjectWithClosedProd1()) + @test_throws MethodError prod(PreserveTypeProd(Float64), ObjectWithClosedProd2(), ObjectWithClosedProd2()) + end +end + +@testitem "GenericProd" begin + include("./prod_setuptests.jl") + + × = (x, y) -> prod(GenericProd(), x, y) + + @testset "GenericProd should use `default_prod_rule` where possible" begin + + # `SomeUnknownObject` does not implement any prod rule (see the fixtures above) + @test SomeUnknownObject() × SomeUnknownObject() isa ProductOf{SomeUnknownObject, SomeUnknownObject} + @test ObjectWithClosedProd1() × SomeUnknownObject() isa ProductOf{ObjectWithClosedProd1, SomeUnknownObject} + @test SomeUnknownObject() × ObjectWithClosedProd1() isa ProductOf{SomeUnknownObject, ObjectWithClosedProd1} + + @test getleft(ObjectWithClosedProd1() × SomeUnknownObject()) === ObjectWithClosedProd1() + @test getright(ObjectWithClosedProd1() × SomeUnknownObject()) === SomeUnknownObject() + @test getleft(SomeUnknownObject() × ObjectWithClosedProd1()) === SomeUnknownObject() + @test getright(SomeUnknownObject() × ObjectWithClosedProd1()) === ObjectWithClosedProd1() + + # Both `ObjectWithClosedProd1` and `ObjectWithClosedProd2` implement `ClosedProd` as a default (see the fixtures above) + @test ObjectWithClosedProd1() × ObjectWithClosedProd1() isa ObjectWithClosedProd1 + @test ObjectWithClosedProd2() × ObjectWithClosedProd2() isa ObjectWithClosedProd2 + end + + @testset "ProdGeneric should simplify a product tree if closed form product available for leaves" begin + d1 = SomeUnknownObject() + d2 = ObjectWithClosedProd1() + d3 = ObjectWithClosedProd2() + + @test (d1 × d2) × d2 == d1 × d2 + @test (d1 × d3) × d3 == d1 × d3 + @test (d2 × d3) × d3 == d2 + @test (d3 × d2) × d2 == d3 + + @test d1 × (d2 × d2) == d1 × d2 + @test d1 × (d3 × d3) == d1 × d3 + @test d2 × (d3 × d3) == d2 + @test d3 × (d2 × d2) == d3 + + @test (d2 × d1) × d2 == (d2 × d1) + @test (d3 × d1) × d3 == (d3 × d1) + @test (d2 × d2) × d1 == (d2 × d1) + @test (d3 × d3) × d1 == (d3 × d1) + + @test d2 × (d1 × d2) == (d1 × d2) + @test d3 × (d1 × d3) == (d1 × d3) + @test d2 × (d2 × d1) == (d2 × d1) + @test d3 × (d3 × d1) == (d3 × d1) + end + + @testset "ProdGeneric should create a product tree if closed form product is not available" begin + d1 = SomeUnknownObject() + + @test 1.0 × 1 × d1 isa ProductOf{ProductOf{Float64, Int}, SomeUnknownObject} + @test 1 × 1.0 × d1 isa ProductOf{ProductOf{Int, Float64}, SomeUnknownObject} + end + + @testset "ProdGeneric should create a linearised product tree if closed form product is not available, but objects are of the same type" begin + d1 = SomeUnknownObject() + d2 = ObjectWithClosedProd1() + + @test d1 × d1 isa ProductOf{SomeUnknownObject, SomeUnknownObject} + + @testset let product = d1 × d1 × d1 + @test product isa LinearizedProductOf{SomeUnknownObject} + @test length(product) === 3 + + # Test that the next prod rule should preserve the type of the linearized product + @test default_prod_rule(product, d1) isa PreserveTypeProd{LinearizedProductOf{SomeUnknownObject}} + end + + @testset let product = (d1 × d1 × d1) × d1 + @test product isa LinearizedProductOf{SomeUnknownObject} + @test length(product) === 4 + + # Test that the next prod rule should preserve the type of the linearized product + @test default_prod_rule(product, d1) isa PreserveTypeProd{LinearizedProductOf{SomeUnknownObject}} + end + + @test d2 × d1 × d1 × d1 isa ProductOf{ObjectWithClosedProd1, LinearizedProductOf{SomeUnknownObject}} + @test d2 × d1 × d1 × d1 × d1 isa ProductOf{ObjectWithClosedProd1, LinearizedProductOf{SomeUnknownObject}} + + # d2 × (...) × d2 should fold if closed prod is available + @test d2 × d1 × d1 × d1 × d1 × d2 == (d2 × d2) × d1 × d1 × d1 × d1 + @test d2 × d1 × d2 × d1 × d2 × d1 × d1 × d2 == (d2 × d2 × d2 × d2) × d1 × d1 × d1 × d1 + end +end diff --git a/test/promotion_setuptests.jl b/test/promotion_setuptests.jl new file mode 100644 index 0000000..86e987c --- /dev/null +++ b/test/promotion_setuptests.jl @@ -0,0 +1,26 @@ +using BayesBase, LinearAlgebra, Distributions, StableRNGs + +function generate_random_distributions(::Type{V} = Any; seed = abs(rand(Int)), Types = (Float32, Float64)) where {V} + rng = StableRNG(seed) + distributions = [] + + # Add `Univariate` distributions + for T in Types + push!(distributions, Normal(rand(rng, T), rand(rng, T))) + push!(distributions, Beta(rand(rng, T), rand(rng, T))) + push!(distributions, Gamma(rand(rng, T), rand(rng, T))) + end + + # Add `Multivariate` distributions + for T in Types, n in (2, 3) + push!(distributions, MvNormal(rand(rng, T, n))) + end + + # Add `Matrixvariate` distributions + for T in Types, n in (2, 3) + push!(distributions, InverseWishart(5one(T), Matrix(Diagonal(ones(n))))) + push!(distributions, Wishart(5one(T), Matrix(Diagonal(ones(n))))) + end + + return filter((dist) -> variate_form(typeof(dist)) <: V, distributions) +end \ No newline at end of file diff --git a/test/promotion_tests.jl b/test/promotion_tests.jl new file mode 100644 index 0000000..0cadc16 --- /dev/null +++ b/test/promotion_tests.jl @@ -0,0 +1,122 @@ +@testitem "TypeConverter" begin + import BayesBase: PromoteTypeConverter + + for original_T in (Float16, Float32, Float64), target_T in (Float16, Float32, Float64), n in (1, 2, 3) + converter = PromoteTypeConverter(target_T, convert) + + @test typeof(@inferred(converter(rand(original_T)))) === target_T + end +end + +@testitem "convert_paramfloattype" begin + include("./promotion_setuptests.jl") + + for T in (Float32, Float64, BigFloat) + @test @inferred(eltype(convert_paramfloattype(T, [1.0, 1.0]))) === T + @test @inferred(eltype(convert_paramfloattype(T, [1.0 1.0; 1.0 1.0]))) === T + @test @inferred(eltype(convert_paramfloattype(T, 1.0))) === T + + for distribution in generate_random_distributions() + @test @inferred(paramfloattype(convert_paramfloattype(T, distribution))) === T + end + end +end + +@testitem "sampletype" begin + include("./promotion_setuptests.jl") + + for distribution in generate_random_distributions() + sample = rand(distribution) + @test @inferred(sampletype(distribution)) === typeof(sample) + end +end + +@testitem "promote_sampletype" begin + include("./promotion_setuptests.jl") + + combinations = [ + Iterators.product(generate_random_distributions(Univariate), generate_random_distributions(Univariate)), + Iterators.product(generate_random_distributions(Multivariate), generate_random_distributions(Multivariate)), + Iterators.product( + generate_random_distributions(Matrixvariate), + generate_random_distributions(Matrixvariate) + ) + ] + for combination in combinations + for distributions in combination + samples = rand.(distributions) + @static if VERSION >= v"1.8" + @test @inferred(promote_sampletype(distributions...)) === promote_type(typeof.(samples)...) + else + @test promote_sampletype(distributions...) === promote_type(typeof.(samples)...) + end + end + end +end + +@testitem "deep_eltype" begin + include("./promotion_setuptests.jl") + + for type in [Float32, Float64, Complex{Float64}, BigFloat] + @test deep_eltype(type) === type + @test deep_eltype(zero(type)) === type + + vector = zeros(type, 10) + matrix = zeros(type, 10, 10) + vector_of_vectors = [vector, vector] + vector_of_matrices = [matrix, matrix] + matrix_of_vector = [vector vector; vector vector] + matrix_of_matrices = [matrix matrix; matrix matrix] + + @test deep_eltype(vector) === type + @test deep_eltype(matrix) === type + @test deep_eltype(vector_of_vectors) === type + @test deep_eltype(vector_of_matrices) === type + @test deep_eltype(matrix_of_vector) === type + @test deep_eltype(matrix_of_matrices) === type + end +end + +@testitem "samplefloattype" begin + include("./promotion_setuptests.jl") + + for distribution in generate_random_distributions() + sample = rand(distribution) + @test @inferred(samplefloattype(distribution)) === deep_eltype(typeof(sample)) + end +end + +@testitem "promote_samplefloattype" begin + include("./promotion_setuptests.jl") + + combinations = [ + Iterators.product(generate_random_distributions(Univariate), generate_random_distributions(Univariate)), + Iterators.product(generate_random_distributions(Univariate), generate_random_distributions(Matrixvariate)), + Iterators.product(generate_random_distributions(Multivariate), generate_random_distributions(Multivariate)), + Iterators.product( + generate_random_distributions(Multivariate), + generate_random_distributions(Matrixvariate) + ), + Iterators.product( + generate_random_distributions(Matrixvariate), + generate_random_distributions(Matrixvariate) + ), + Iterators.product( + generate_random_distributions(Univariate), + generate_random_distributions(Matrixvariate), + generate_random_distributions(Matrixvariate) + ) + ] + + for combination in combinations + for distributions in combination + samples = rand.(distributions) + @static if VERSION >= v"1.8" + @test @inferred(promote_samplefloattype(distributions...)) === + promote_type(deep_eltype.(typeof.(samples))...) + else + @test promote_samplefloattype(distributions...) === promote_type(deep_eltype.(typeof.(samples))...) + end + end + end +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 46be8c2..05eef90 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,12 @@ -using BayesBase -using Test +using Aqua, CpuId, ReTestItems, BayesBase -@testset "BayesBase.jl" begin - # Write your tests here. -end +# `ambiguities = false` - there are quite some ambiguities, but these should be normal and should not be encountered under normal circumstances +# `piracy = false` - we extend/add some of the methods to the objects defined in the Distributions.jl +Aqua.test_all(BayesBase; ambiguities=false, piracy=false) + +runtests( + BayesBase; + nworkers=cpucores(), + nworker_threads=Int(cputhreads() / cpucores()), + memory_threshold=1.0, +) diff --git a/test/statsfuns_tests.jl b/test/statsfuns_tests.jl new file mode 100644 index 0000000..a5a7b01 --- /dev/null +++ b/test/statsfuns_tests.jl @@ -0,0 +1,29 @@ +@testitem "mirrorlog" begin + for T in (Float32, Float64, BigFloat) + foreach(rand(T, 10)) do number + @test mirrorlog(number) ≈ log(one(number) - number) + end + end +end + +@testitem "xtlog" begin + for T in (Float32, Float64, BigFloat) + foreach(rand(T, 10)) do number + @test xtlog(number) ≈ number * log(number) + end + end +end + +@testitem "clamplog" begin + using TinyHugeNumbers + + for T in (Float32, Float64, BigFloat) + foreach(rand(T, 10)) do number + @test clamplog(number + 2tiny) ≈ log(number + 2tiny) + end + + @test clamplog(zero(T)) ≈ log(convert(T, tiny)) + end +end + +