Skip to content

Commit

Permalink
Merge pull request #4 from biaslab/dev-static-arrays-extension
Browse files Browse the repository at this point in the history
Handle static arrays better
  • Loading branch information
bvdmitri authored Oct 4, 2023
2 parents 2d31ddc + 8ca7495 commit 5b32683
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 2 deletions.
13 changes: 11 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,23 @@ version = "1.1.0"
[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PositiveFactorizations = "85a6dd25-e78a-55b7-8502-1745935b8125"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"

[weakdeps]
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"

[extensions]
StaticArraysCoreExt = "StaticArraysCore"

[compat]
PositiveFactorizations = "0.2"
StaticArraysCore = "1"
julia = "1.6"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "StaticArrays"]
test = ["Test", "StaticArrays", "StaticArraysCore"]
19 changes: 19 additions & 0 deletions ext/StaticArraysCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
module StaticArraysCoreExt # Should be same name as the file (just like a normal package)

using FastCholesky, PositiveFactorizations, StaticArraysCore, LinearAlgebra

function FastCholesky.fastcholesky(input::StaticArraysCore.StaticArray)
C = cholesky(input, check = false)
f = C.factors
u = C.uplo
c = C.info
if !LinearAlgebra.issuccess(C)
C_ = cholesky(Positive, C, tol = PositiveFactorizations.default_δ(C))
f = typeof(C.factors)(C_.factors)
u = C_.uplo
c = C_.info
end
return Cholesky(f, u, c)
end

end # module
5 changes: 5 additions & 0 deletions src/FastCholesky.jl
Original file line number Diff line number Diff line change
Expand Up @@ -196,4 +196,9 @@ cholinv_logdet(input::UniformScaling) = inv(input), logdet(input)
cholinv_logdet(input::Diagonal) = inv(input), logdet(input)
cholinv_logdet(input::Number) = inv(input), log(abs(input))

# Extensions
@static if !isdefined(Base, :get_extension)
include("../ext/StaticArraysCoreExt.jl")
end

end
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ using FastCholesky
using Test
using LinearAlgebra
using StaticArrays
using StaticArraysCore

make_rand_diagonal(size::Number) = Diagonal(10rand(size))
make_rand_posdef(size::Number) = collect(make_rand_hermitian(size))
Expand Down Expand Up @@ -45,6 +46,9 @@ end
@test all(cholinv_logdet(input) .≈ (inv(input), logdet(input)))
@test cholsqrt(input) * cholsqrt(input)' sqrt(input) * sqrt(input)'
@test cholsqrt(input) * cholsqrt(input)' input

# Check that we do not lose the static type in the process for example
@test typeof(cholesky(input)) === typeof(fastcholesky(input))
end
end

Expand Down

0 comments on commit 5b32683

Please sign in to comment.