diff --git a/src/distributions/wishart.jl b/src/distributions/wishart.jl index 773166a4..1fcbb238 100644 --- a/src/distributions/wishart.jl +++ b/src/distributions/wishart.jl @@ -203,6 +203,12 @@ function BayesBase.prod(::PreserveTypeProd{Distribution}, left::WishartFast, rig return WishartFast(df, invV) end +BayesBase.default_prod_rule(::Type{<:Wishart}, ::Type{<:WishartFast}) = PreserveTypeProd(Distribution) + +function BayesBase.prod(::PreserveTypeProd{Distribution}, left::Wishart, right::WishartFast) + return prod(PreserveTypeProd(Distribution), convert(WishartFast, left), right) +end + function BayesBase.insupport(ef::ExponentialFamilyDistribution{WishartFast}, x::Matrix) return size(getindex(unpack_parameters(ef), 2)) == size(x) && isposdef(x) end diff --git a/src/distributions/wishart_inverse.jl b/src/distributions/wishart_inverse.jl index 4cb80207..6c308999 100644 --- a/src/distributions/wishart_inverse.jl +++ b/src/distributions/wishart_inverse.jl @@ -225,6 +225,18 @@ function BayesBase.prod(::PreserveTypeProd{Distribution}, left::InverseWishartFa return InverseWishartFast(df, V) end +BayesBase.default_prod_rule(::Type{<:InverseWishart}, ::Type{<:InverseWishartFast}) = PreserveTypeProd(Distribution) + +function BayesBase.prod(::PreserveTypeProd{Distribution}, left::InverseWishart, right::InverseWishartFast) + return prod(PreserveTypeProd(Distribution), convert(InverseWishartFast, left), right) +end + +BayesBase.default_prod_rule(::Type{<:InverseWishart}, ::Type{<:InverseWishart}) = PreserveTypeProd(Distribution) + +function BayesBase.prod(::PreserveTypeProd{Distribution}, left::InverseWishart, right::InverseWishart) + return prod(PreserveTypeProd(Distribution), convert(InverseWishartFast, left), convert(InverseWishartFast, right)) +end + function BayesBase.insupport(ef::ExponentialFamilyDistribution{InverseWishartFast}, x::Matrix) return size(getindex(unpack_parameters(ef), 2)) == size(x) && isposdef(x) end diff --git a/test/distributions/wishart_inverse_tests.jl b/test/distributions/wishart_inverse_tests.jl index 9d76f41b..3c586004 100644 --- a/test/distributions/wishart_inverse_tests.jl +++ b/test/distributions/wishart_inverse_tests.jl @@ -228,3 +228,37 @@ end end end end + +@testitem "InverseWishart: prod between InverseWishart and InverseWishartFast" begin + include("distributions_setuptests.jl") + + import ExponentialFamily: InverseWishartFast + import Distributions: InverseWishart + + for Sleft in rand(InverseWishart(10, Array(Eye(2))), 2), Sright in rand(InverseWishart(10, Array(Eye(2))), 2), νright in (6, 7), νleft in (4, 5) + let left = InverseWishart(νleft, Sleft), right = InverseWishart(νleft, Sleft), right_fast = convert(InverseWishartFast, right) + # Test commutativity of the product + prod_result1 = prod(PreserveTypeProd(Distribution), left, right_fast) + prod_result2 = prod(PreserveTypeProd(Distribution), right_fast, left) + + @test prod_result1.ν ≈ prod_result2.ν + @test prod_result1.S ≈ prod_result2.S + + # Test that the product preserves type + @test prod_result1 isa InverseWishartFast + @test prod_result2 isa InverseWishartFast + + # prod stays if we convert fisrt and then do product + left_fast = convert(InverseWishartFast, left) + prod_fast = prod(ClosedProd(), left_fast, right_fast) + + @test prod_fast.ν ≈ prod_result1.ν + @test prod_fast.S ≈ prod_result2.S + + # prod for Inverse Wishart is defenied + prod_result_not_fast = prod(PreserveTypeProd(Distribution), left, right) + @test prod_result_not_fast.ν ≈ prod_result1.ν + @test prod_result_not_fast.S ≈ prod_result1.S + end + end +end diff --git a/test/distributions/wishart_tests.jl b/test/distributions/wishart_tests.jl index 33e90f71..fca2acbc 100644 --- a/test/distributions/wishart_tests.jl +++ b/test/distributions/wishart_tests.jl @@ -131,3 +131,33 @@ end end end end + +@testitem "Wishart: prod between Wishart and WishartFast" begin + include("distributions_setuptests.jl") + + import ExponentialFamily: WishartFast + import Distributions: Wishart + + for Sleft in rand(Wishart(10, Array(Eye(2))), 2), Sright in rand(Wishart(10, Array(Eye(2))), 2), νright in (6, 7), νleft in (4, 5) + let left = Wishart(νleft, Sleft), right = WishartFast(νright, Sright) + # Test commutativity of the product + prod_result1 = prod(PreserveTypeProd(Distribution), left, right) + prod_result2 = prod(PreserveTypeProd(Distribution), right, left) + + @test prod_result1.ν ≈ prod_result2.ν + @test prod_result1.invS ≈ prod_result2.invS + + # Test that the product preserves type + @test prod_result1 isa WishartFast + @test prod_result2 isa WishartFast + + # prod stays the same if we convert fisrt and then do product + left_fast = convert(WishartFast, left) + prod_fast = prod(ClosedProd(), left_fast, right) + + @test prod_fast.ν ≈ prod_result1.ν + @test prod_fast.invS ≈ prod_result2.invS + end + end +end +