Skip to content

Commit

Permalink
fix: add prod between inverse wisharts
Browse files Browse the repository at this point in the history
  • Loading branch information
Nimrais committed Oct 24, 2024
1 parent 2c11b39 commit 9f311fd
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 5 deletions.
6 changes: 6 additions & 0 deletions src/distributions/wishart_inverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,12 @@ function BayesBase.prod(::PreserveTypeProd{Distribution}, left::InverseWishart,
return prod(PreserveTypeProd(Distribution), convert(InverseWishartFast, left), right)
end

BayesBase.default_prod_rule(::Type{<:InverseWishart}, ::Type{<:InverseWishart}) = PreserveTypeProd(Distribution)

function BayesBase.prod(::PreserveTypeProd{Distribution}, left::InverseWishart, right::InverseWishart)
return prod(PreserveTypeProd(Distribution), convert(InverseWishartFast, left), convert(InverseWishartFast, right))
end

function BayesBase.insupport(ef::ExponentialFamilyDistribution{InverseWishartFast}, x::Matrix)
return size(getindex(unpack_parameters(ef), 2)) == size(x) && isposdef(x)
end
Expand Down
15 changes: 10 additions & 5 deletions test/distributions/wishart_inverse_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -236,10 +236,10 @@ end
import Distributions: InverseWishart

for Sleft in rand(InverseWishart(10, Array(Eye(2))), 2), Sright in rand(InverseWishart(10, Array(Eye(2))), 2), νright in (6, 7), νleft in (4, 5)
let left = InverseWishart(νleft, Sleft), right = InverseWishartFast(νright, Sright)
let left = InverseWishart(νleft, Sleft), right = InverseWishart(νleft, Sleft), right_fast = convert(InverseWishartFast, right)
# Test commutativity of the product
prod_result1 = prod(PreserveTypeProd(Distribution), left, right)
prod_result2 = prod(PreserveTypeProd(Distribution), right, left)
prod_result1 = prod(PreserveTypeProd(Distribution), left, right_fast)
prod_result2 = prod(PreserveTypeProd(Distribution), right_fast, left)

@test prod_result1.ν prod_result2.ν
@test prod_result1.S prod_result2.S
Expand All @@ -248,12 +248,17 @@ end
@test prod_result1 isa InverseWishartFast
@test prod_result2 isa InverseWishartFast

# prod prod stays if we convert fisrt and then do product
# prod stays if we convert fisrt and then do product
left_fast = convert(InverseWishartFast, left)
prod_fast = prod(ClosedProd(), left_fast, right)
prod_fast = prod(ClosedProd(), left_fast, right_fast)

@test prod_fast.ν prod_result1.ν
@test prod_fast.S prod_result2.S

# prod for Inverse Wishart is defenied
prod_result_not_fast = prod(PreserveTypeProd(Distribution), left, right)
@test prod_result_not_fast.ν prod_result1.ν
@test prod_result_not_fast.S prod_result1.S
end
end
end

0 comments on commit 9f311fd

Please sign in to comment.