From 9f311fd65b3f9bf8fbc23f669dbd07b6536d156d Mon Sep 17 00:00:00 2001 From: Mykola Lukashchuk Date: Thu, 24 Oct 2024 13:56:37 +0200 Subject: [PATCH] fix: add prod between inverse wisharts --- src/distributions/wishart_inverse.jl | 6 ++++++ test/distributions/wishart_inverse_tests.jl | 15 ++++++++++----- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/src/distributions/wishart_inverse.jl b/src/distributions/wishart_inverse.jl index 676c9b0c..6c308999 100644 --- a/src/distributions/wishart_inverse.jl +++ b/src/distributions/wishart_inverse.jl @@ -231,6 +231,12 @@ function BayesBase.prod(::PreserveTypeProd{Distribution}, left::InverseWishart, return prod(PreserveTypeProd(Distribution), convert(InverseWishartFast, left), right) end +BayesBase.default_prod_rule(::Type{<:InverseWishart}, ::Type{<:InverseWishart}) = PreserveTypeProd(Distribution) + +function BayesBase.prod(::PreserveTypeProd{Distribution}, left::InverseWishart, right::InverseWishart) + return prod(PreserveTypeProd(Distribution), convert(InverseWishartFast, left), convert(InverseWishartFast, right)) +end + function BayesBase.insupport(ef::ExponentialFamilyDistribution{InverseWishartFast}, x::Matrix) return size(getindex(unpack_parameters(ef), 2)) == size(x) && isposdef(x) end diff --git a/test/distributions/wishart_inverse_tests.jl b/test/distributions/wishart_inverse_tests.jl index 7f6ffcc6..3c586004 100644 --- a/test/distributions/wishart_inverse_tests.jl +++ b/test/distributions/wishart_inverse_tests.jl @@ -236,10 +236,10 @@ end import Distributions: InverseWishart for Sleft in rand(InverseWishart(10, Array(Eye(2))), 2), Sright in rand(InverseWishart(10, Array(Eye(2))), 2), νright in (6, 7), νleft in (4, 5) - let left = InverseWishart(νleft, Sleft), right = InverseWishartFast(νright, Sright) + let left = InverseWishart(νleft, Sleft), right = InverseWishart(νleft, Sleft), right_fast = convert(InverseWishartFast, right) # Test commutativity of the product - prod_result1 = prod(PreserveTypeProd(Distribution), left, right) - prod_result2 = prod(PreserveTypeProd(Distribution), right, left) + prod_result1 = prod(PreserveTypeProd(Distribution), left, right_fast) + prod_result2 = prod(PreserveTypeProd(Distribution), right_fast, left) @test prod_result1.ν ≈ prod_result2.ν @test prod_result1.S ≈ prod_result2.S @@ -248,12 +248,17 @@ end @test prod_result1 isa InverseWishartFast @test prod_result2 isa InverseWishartFast - # prod prod stays if we convert fisrt and then do product + # prod stays if we convert fisrt and then do product left_fast = convert(InverseWishartFast, left) - prod_fast = prod(ClosedProd(), left_fast, right) + prod_fast = prod(ClosedProd(), left_fast, right_fast) @test prod_fast.ν ≈ prod_result1.ν @test prod_fast.S ≈ prod_result2.S + + # prod for Inverse Wishart is defenied + prod_result_not_fast = prod(PreserveTypeProd(Distribution), left, right) + @test prod_result_not_fast.ν ≈ prod_result1.ν + @test prod_result_not_fast.S ≈ prod_result1.S end end end