diff --git a/Project.toml b/Project.toml index 2db72fbc..97a65aff 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Bijectors" uuid = "76274a88-744f-5084-9051-94815aaf08c4" -version = "0.15.1" +version = "0.15.2" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index a7fe0b92..93a5b089 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -65,13 +65,13 @@ struct CorrBijector <: Bijector end with_logabsdet_jacobian(b::CorrBijector, x) = transform(b, x), logabsdetjac(b, x) -function transform(b::CorrBijector, X::AbstractMatrix{<:Real}) +function transform(b::CorrBijector, X) w = cholesky_upper(X) r = _link_chol_lkj(w) return r end -function with_logabsdet_jacobian(ib::Inverse{CorrBijector}, y::AbstractMatrix{<:Real}) +function with_logabsdet_jacobian(ib::Inverse{CorrBijector}, y) U, logJ = _inv_link_chol_lkj(y) K = size(U, 1) for j in 2:(K - 1) @@ -80,8 +80,8 @@ function with_logabsdet_jacobian(ib::Inverse{CorrBijector}, y::AbstractMatrix{<: return pd_from_upper(U), logJ end -logabsdetjac(::Inverse{CorrBijector}, Y::AbstractMatrix{<:Real}) = _logabsdetjac_inv_corr(Y) -function logabsdetjac(b::CorrBijector, X::AbstractMatrix{<:Real}) +logabsdetjac(::Inverse{CorrBijector}, Y) = _logabsdetjac_inv_corr(Y) +function logabsdetjac(b::CorrBijector, X) #= It may be more efficient if we can use un-contraint value to prevent call of b It's recommended to directly call @@ -135,7 +135,7 @@ function logabsdetjac(b::VecCorrBijector, x) return -logabsdetjac(inverse(b), b(x)) end -function with_logabsdet_jacobian(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) +function with_logabsdet_jacobian(::Inverse{VecCorrBijector}, y) U_logJ = _inv_link_chol_lkj(y) # workaround for `Tracker.TrackedTuple` not supporting iteration U, logJ = U_logJ[1], U_logJ[2] @@ -146,7 +146,7 @@ function with_logabsdet_jacobian(::Inverse{VecCorrBijector}, y::AbstractVector{< return pd_from_upper(U), logJ end -function logabsdetjac(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) +function logabsdetjac(::Inverse{VecCorrBijector}, y) return _logabsdetjac_inv_corr(y) end