Skip to content

Commit

Permalink
fix: add product between wishart and wishartfast
Browse files Browse the repository at this point in the history
  • Loading branch information
Nimrais committed Oct 24, 2024
1 parent e404c39 commit 25283dc
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/distributions/wishart.jl
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,12 @@ function BayesBase.prod(::PreserveTypeProd{Distribution}, left::WishartFast, rig
return WishartFast(df, invV)
end

BayesBase.default_prod_rule(::Type{<:Wishart}, ::Type{<:WishartFast}) = PreserveTypeProd(Distribution)

function BayesBase.prod(::PreserveTypeProd{Distribution}, left::Wishart, right::WishartFast)
return prod(PreserveTypeProd(Distribution), convert(WishartFast, left), right)
end

function BayesBase.insupport(ef::ExponentialFamilyDistribution{WishartFast}, x::Matrix)
return size(getindex(unpack_parameters(ef), 2)) == size(x) && isposdef(x)
end
Expand Down
6 changes: 6 additions & 0 deletions src/distributions/wishart_inverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,12 @@ function BayesBase.prod(::PreserveTypeProd{Distribution}, left::InverseWishartFa
return InverseWishartFast(df, V)
end

BayesBase.default_prod_rule(::Type{<:InverseWishart}, ::Type{<:InverseWishartFast}) = PreserveTypeProd(Distribution)

function BayesBase.prod(::PreserveTypeProd{Distribution}, left::InverseWishart, right::InverseWishartFast)
return prod(PreserveTypeProd(Distribution), convert(InverseWishartFast, left), right)
end

function BayesBase.insupport(ef::ExponentialFamilyDistribution{InverseWishartFast}, x::Matrix)
return size(getindex(unpack_parameters(ef), 2)) == size(x) && isposdef(x)
end
Expand Down
29 changes: 29 additions & 0 deletions test/distributions/wishart_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,32 @@ end
end
end
end

@testitem "Wishart: prod between Wishart and WishartFast" begin
include("distributions_setuptests.jl")

import ExponentialFamily: WishartFast
import Distributions: Wishart

for Sleft in rand(Wishart(10, Array(Eye(2))), 2), Sright in rand(Wishart(10, Array(Eye(2))), 2), νright in (6, 7), νleft in (4, 5)
let left = Wishart(νleft, Sleft), right = WishartFast(νright, Sright)
# Test commutativity of the product
prod_result1 = prod(PreserveTypeProd(Distribution), left, right)
prod_result2 = prod(PreserveTypeProd(Distribution), right, left)

@test prod_result1.ν prod_result2.ν
@test prod_result1.invS prod_result2.invS

# Test that the product preserves type
@test prod_result1 isa WishartFast
@test prod_result2 isa WishartFast

# prod the same before conversion
left_fast = convert(WishartFast, left)
prod_fast = prod(ClosedProd(), left_fast, right)

@test prod_fast.ν prod_result1.ν
@test prod_fast.invS prod_result2.invS
end
end
end

0 comments on commit 25283dc

Please sign in to comment.