diff --git a/src/distributions/wishart.jl b/src/distributions/wishart.jl index 773166a4..1fcbb238 100644 --- a/src/distributions/wishart.jl +++ b/src/distributions/wishart.jl @@ -203,6 +203,12 @@ function BayesBase.prod(::PreserveTypeProd{Distribution}, left::WishartFast, rig return WishartFast(df, invV) end +BayesBase.default_prod_rule(::Type{<:Wishart}, ::Type{<:WishartFast}) = PreserveTypeProd(Distribution) + +function BayesBase.prod(::PreserveTypeProd{Distribution}, left::Wishart, right::WishartFast) + return prod(PreserveTypeProd(Distribution), convert(WishartFast, left), right) +end + function BayesBase.insupport(ef::ExponentialFamilyDistribution{WishartFast}, x::Matrix) return size(getindex(unpack_parameters(ef), 2)) == size(x) && isposdef(x) end diff --git a/src/distributions/wishart_inverse.jl b/src/distributions/wishart_inverse.jl index 4cb80207..676c9b0c 100644 --- a/src/distributions/wishart_inverse.jl +++ b/src/distributions/wishart_inverse.jl @@ -225,6 +225,12 @@ function BayesBase.prod(::PreserveTypeProd{Distribution}, left::InverseWishartFa return InverseWishartFast(df, V) end +BayesBase.default_prod_rule(::Type{<:InverseWishart}, ::Type{<:InverseWishartFast}) = PreserveTypeProd(Distribution) + +function BayesBase.prod(::PreserveTypeProd{Distribution}, left::InverseWishart, right::InverseWishartFast) + return prod(PreserveTypeProd(Distribution), convert(InverseWishartFast, left), right) +end + function BayesBase.insupport(ef::ExponentialFamilyDistribution{InverseWishartFast}, x::Matrix) return size(getindex(unpack_parameters(ef), 2)) == size(x) && isposdef(x) end diff --git a/test/distributions/wishart_tests.jl b/test/distributions/wishart_tests.jl index 33e90f71..efbcfce2 100644 --- a/test/distributions/wishart_tests.jl +++ b/test/distributions/wishart_tests.jl @@ -131,3 +131,32 @@ end end end end + +@testitem "Wishart: prod between Wishart and WishartFast" begin + include("distributions_setuptests.jl") + + import ExponentialFamily: WishartFast + import Distributions: Wishart + + for Sleft in rand(Wishart(10, Array(Eye(2))), 2), Sright in rand(Wishart(10, Array(Eye(2))), 2), νright in (6, 7), νleft in (4, 5) + let left = Wishart(νleft, Sleft), right = WishartFast(νright, Sright) + # Test commutativity of the product + prod_result1 = prod(PreserveTypeProd(Distribution), left, right) + prod_result2 = prod(PreserveTypeProd(Distribution), right, left) + + @test prod_result1.ν ≈ prod_result2.ν + @test prod_result1.invS ≈ prod_result2.invS + + # Test that the product preserves type + @test prod_result1 isa WishartFast + @test prod_result2 isa WishartFast + + # prod the same before conversion + left_fast = convert(WishartFast, left) + prod_fast = prod(ClosedProd(), left_fast, right) + + @test prod_fast.ν ≈ prod_result1.ν + @test prod_fast.invS ≈ prod_result2.invS + end + end +end \ No newline at end of file