Skip to content

Commit

Permalink
Basis
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed May 13, 2024
1 parent 9871a25 commit 20dd0d0
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 10 deletions.
44 changes: 36 additions & 8 deletions DifferentiationInterface/src/first_order/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,11 @@ function value_and_jacobian_onearg_aux(
f::F, backend, x::AbstractArray, extras::PushforwardJacobianExtras
) where {F}
pushforward_extras_same = prepare_pushforward_same_point(
f, backend, x, basis(backend, x, 1), extras.pushforward_extras
f,
backend,
x,
basis(backend, x, first(CartesianIndices(x))),
extras.pushforward_extras,
)
y = f(x) # TODO: remove
jac = stack(CartesianIndices(x); dims=2) do j
Expand All @@ -115,7 +119,7 @@ function value_and_jacobian_onearg_aux(
f::F, backend, x::AbstractArray, extras::PullbackJacobianExtras
) where {F}
pullback_extras_same = prepare_pullback_same_point(
f, backend, x, basis(backend, y, 1), extras.pullback_extras
f, backend, x, basis(backend, y, first(CartesianIndices(y))), extras.pullback_extras
)
y = f(x) # TODO: remove
jac = stack(CartesianIndices(y); dims=1) do i
Expand All @@ -140,7 +144,11 @@ function value_and_jacobian_onearg_aux!(
f::F, jac::AbstractMatrix, backend, x::AbstractArray, extras::PushforwardJacobianExtras
) where {F}
pushforward_extras_same = prepare_pushforward_same_point(
f, backend, x, basis(backend, x, 1), extras.pushforward_extras
f,
backend,
x,
basis(backend, x, first(CartesianIndices(x))),
extras.pushforward_extras,
)
y = f(x) # TODO: remove
for (k, j) in enumerate(CartesianIndices(x))
Expand All @@ -155,7 +163,7 @@ function value_and_jacobian_onearg_aux!(
f::F, jac::AbstractMatrix, backend, x::AbstractArray, extras::PullbackJacobianExtras
) where {F}
pullback_extras_same = prepare_pullback_same_point(
f, backend, x, basis(backend, y, 1), extras.pullback_extras
f, backend, x, basis(backend, y, first(CartesianIndices(y))), extras.pullback_extras
)
y = f(x) # TODO: remove
for (k, i) in enumerate(CartesianIndices(y))
Expand Down Expand Up @@ -198,7 +206,12 @@ function value_and_jacobian_twoarg_aux(
f!::F, y, backend, x::AbstractArray, extras::PushforwardJacobianExtras
) where {F}
pushforward_extras_same = prepare_pushforward_same_point(
f!, y, backend, x, basis(backend, x, 1), extras.pushforward_extras
f!,
y,
backend,
x,
basis(backend, x, first(CartesianIndices(x))),
extras.pushforward_extras,
)
jac = stack(CartesianIndices(x); dims=2) do j
dx_j = basis(backend, x, j)
Expand All @@ -213,7 +226,12 @@ function value_and_jacobian_twoarg_aux(
f!::F, y, backend, x::AbstractArray, extras::PullbackJacobianExtras
) where {F}
pullback_extras_same = prepare_pullback_same_point(
f!, y, backend, x, basis(backend, y, 1), extras.pullback_extras
f!,
y,
backend,
x,
basis(backend, y, first(CartesianIndices(y))),
extras.pullback_extras,
)
jac = stack(CartesianIndices(y); dims=1) do i
dy_i = basis(backend, y, i)
Expand Down Expand Up @@ -244,7 +262,12 @@ function value_and_jacobian_twoarg_aux!(
extras::PushforwardJacobianExtras,
) where {F}
pushforward_extras_same = prepare_pushforward_same_point(
f!, y, backend, x, basis(backend, x, 1), extras.pushforward_extras
f!,
y,
backend,
x,
basis(backend, x, first(CartesianIndices(x))),
extras.pushforward_extras,
)
for (k, j) in enumerate(CartesianIndices(x))
dx_j = basis(backend, x, j)
Expand All @@ -259,7 +282,12 @@ function value_and_jacobian_twoarg_aux!(
f!::F, y, jac::AbstractMatrix, backend, x::AbstractArray, extras::PullbackJacobianExtras
) where {F}
pullback_extras_same = prepare_pullback_same_point(
f!, y, backend, x, basis(backend, y, 1), extras.pullback_extras
f!,
y,
backend,
x,
basis(backend, y, first(CartesianIndices(y))),
extras.pullback_extras,
)
for (k, i) in enumerate(CartesianIndices(y))
dy_i = basis(backend, y, i)
Expand Down
4 changes: 2 additions & 2 deletions DifferentiationInterface/src/second_order/hessian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ function hessian(
f::F, backend::SecondOrder, x, extras::HessianExtras=prepare_hessian(f, backend, x)
) where {F}
hvp_extras_same = prepare_hvp_same_point(
f, backend, x, basis(backend, x, 1), extras.hvp_extras
f, backend, x, basis(backend, x, first(CartesianIndices(x))), extras.hvp_extras
)
hess = stack(vec(CartesianIndices(x))) do j
hess_col_j = hvp(f, backend, x, basis(backend, x, j), hvp_extras_same)
Expand Down Expand Up @@ -84,7 +84,7 @@ function hessian!(
extras::HessianExtras=prepare_hessian(f, backend, x),
) where {F}
hvp_extras_same = prepare_hvp_same_point(
f, backend, x, basis(backend, x, 1), extras.hvp_extras
f, backend, x, basis(backend, x, first(CartesianIndices(x))), extras.hvp_extras
)
for (k, j) in enumerate(CartesianIndices(x))
hess_col_j = reshape(view(hess, :, k), size(x))
Expand Down

0 comments on commit 20dd0d0

Please sign in to comment.