Skip to content

Commit

Permalink
Symmetric coloring and decompression for Hessians (#272)
Browse files Browse the repository at this point in the history
* Symmetric Hessian coloring

* Bump SMC compat
  • Loading branch information
gdalle authored May 27, 2024
1 parent b262365 commit f70139f
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 6 deletions.
2 changes: 1 addition & 1 deletion DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ PolyesterForwardDiff = "0.1.1"
ReverseDiff = "1.15.1"
SparseArrays = "<0.0.1,1"
SparseConnectivityTracer = "0.4.2"
SparseMatrixColorings = "0.3.1"
SparseMatrixColorings = "0.3.2"
Symbolics = "5.27.1"
Tapir = "0.2.4"
Tracker = "0.2.33"
Expand Down
6 changes: 4 additions & 2 deletions DifferentiationInterface/src/DifferentiationInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,12 @@ using SparseArrays: SparseMatrixCSC, nonzeros, nzrange, rowvals, sparse
using SparseMatrixColorings:
GreedyColoringAlgorithm,
color_groups,
decompress_columns!,
decompress_columns,
decompress_columns!,
decompress_rows,
decompress_rows!,
decompress_rows
decompress_symmetric,
decompress_symmetric!

abstract type Extras end

Expand Down
6 changes: 3 additions & 3 deletions DifferentiationInterface/src/sparse/hessian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ end
function prepare_hessian(f::F, backend::AutoSparse, x) where {F}
initial_sparsity = hessian_sparsity(f, x, sparsity_detector(backend))
sparsity = col_major(initial_sparsity)
colors = column_coloring(sparsity, coloring_algorithm(backend)) # no star coloring
colors = symmetric_coloring(sparsity, coloring_algorithm(backend))
groups = color_groups(colors)
seeds = map(groups) do group
seed = zero(x)
Expand All @@ -41,7 +41,7 @@ function hessian!(f::F, hess, backend::AutoSparse, x, extras::SparseHessianExtra
hvp!(f, products[k], backend, x, seeds[k], hvp_extras_same)
copyto!(view(compressed, :, k), vec(products[k]))
end
decompress_columns!(hess, sparsity, compressed, colors)
decompress_symmetric!(hess, sparsity, compressed, colors)
return hess
end

Expand All @@ -51,5 +51,5 @@ function hessian(f::F, backend::AutoSparse, x, extras::SparseHessianExtras) wher
compressed = stack(eachindex(seeds, products); dims=2) do k
vec(hvp(f, backend, x, seeds[k], hvp_extras_same))
end
return decompress_columns(sparsity, compressed, colors)
return decompress_symmetric(sparsity, compressed, colors)
end

0 comments on commit f70139f

Please sign in to comment.