diff --git a/Project.toml b/Project.toml index b45bed1..cd1f7d0 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ParameterHandling" uuid = "2412ca09-6db7-441c-8e3a-88d5709968c5" authors = ["Invenia Technical Computing Corporation"] -version = "0.4.2" +version = "0.4.3" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/parameters_matrix.jl b/src/parameters_matrix.jl index 4332254..9b7e453 100644 --- a/src/parameters_matrix.jl +++ b/src/parameters_matrix.jl @@ -1,18 +1,18 @@ """ - nearest_orthogonal_matrix(X::StridedMatrix) + nearest_orthogonal_matrix(X::AbstractMatrix{<:Union{Real,Complex}}) Project `X` onto the closest orthogonal matrix in Frobenius norm. Originally used in varz: https://github.com/wesselb/varz/blob/master/varz/vars.py#L446 """ -@inline function nearest_orthogonal_matrix(X::StridedMatrix{<:Union{Real,Complex}}) +@inline function nearest_orthogonal_matrix(X::AbstractMatrix{<:Union{Real,Complex}}) # Inlining necessary for type inference for some reason. U, _, V = svd(X) return U * V' end """ - orthogonal(X::StridedMatrix{<:Real}) + orthogonal(X::AbstractMatrix{<:Real}) Produce a parameter whose `value` is constrained to be an orthogonal matrix. The argument `X` need not be orthogonal. @@ -22,9 +22,9 @@ Frobenius norm) and is overparametrised as a consequence. Originally used in varz: https://github.com/wesselb/varz/blob/master/varz/vars.py#L446 """ -orthogonal(X::StridedMatrix{<:Real}) = Orthogonal(X) +orthogonal(X::AbstractMatrix{<:Real}) = Orthogonal(X) -struct Orthogonal{TX<:StridedMatrix{<:Real}} <: AbstractParameter +struct Orthogonal{TX<:AbstractMatrix{<:Real}} <: AbstractParameter X::TX end @@ -39,14 +39,14 @@ function flatten(::Type{T}, X::Orthogonal) where {T<:Real} end """ - positive_definite(X::StridedMatrix{<:Real}) + positive_definite(X::AbstractMatrix{<:Real}) Produce a parameter whose `value` is constrained to be a positive-definite matrix. The argument `X` needs to be a positive-definite matrix (see https://en.wikipedia.org/wiki/Definite_matrix). The unconstrained parameter is a `LowerTriangular` matrix, stored as a vector. """ -function positive_definite(X::StridedMatrix{<:Real}) +function positive_definite(X::AbstractMatrix{<:Real}) isposdef(X) || throw(ArgumentError("X is not positive-definite")) return PositiveDefinite(tril_to_vec(cholesky(X).L)) end