Skip to content

Commit

Permalink
change the definition of fastcholesky for numbers and uniform scaling
Browse files Browse the repository at this point in the history
  • Loading branch information
bvdmitri committed Nov 24, 2023
1 parent 1c08e12 commit df7b7d5
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 7 deletions.
16 changes: 12 additions & 4 deletions src/FastCholesky.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,21 @@ function fastcholesky(input::AbstractMatrix)
# The `PositiveFactorizations.default_δ` should small enough in majority of the cases
return cholesky(PositiveFactorizations.Positive, Hermitian(input), tol = PositiveFactorizations.default_δ(input))
end

fastcholesky(input::Number) = cholesky(input)
fastcholesky(input::Diagonal) = cholesky(input)
fastcholesky(input::Hermitian) = cholesky(PositiveFactorizations.Positive, input)
fastcholesky(input::Number) = input
fastcholesky!(input::Number) = input

fastcholesky(x::UniformScaling) = sqrt(x.λ) * I
fastcholesky!(x::UniformScaling) = sqrt(x.λ) * I
function fastcholesky(x::UniformScaling)
return error(
"`fastcholesky` is not defined for `UniformScaling`. The shape is not determined."
)
end
function fastcholesky!(x::UniformScaling)
return error(
"`fastcholesky!` is not defined for `UniformScaling`. The shape is not determined."
)
end

function fastcholesky(input::Matrix{<:BlasFloat})
C = fastcholesky!(copy(input))
Expand Down
7 changes: 4 additions & 3 deletions test/fastcholesky_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ end
@test chollogdet(I) zero
@test all(cholinv_logdet(one * I) .≈ (one * I, zero))
@test_throws ArgumentError chollogdet(two * I)
@test fastcholesky(I) I
@test fastcholesky!(I) I
@test_throws ErrorException fastcholesky(I)
@test_throws ErrorException fastcholesky!(I)
end
end

Expand All @@ -73,7 +73,8 @@ end

for Type in SupportedTypes
let number = rand(Type)
@test fastcholesky(number) === number
@test size(fastcholesky(number).L) == (1, 1)
@test all(fastcholesky(number).L .≈ sqrt(number))
@test cholinv(number) inv(number)
@test cholsqrt(number) sqrt(number)
@test chollogdet(number) logdet(number)
Expand Down

0 comments on commit df7b7d5

Please sign in to comment.