Skip to content

Commit

Permalink
Merge pull request #456 from AayushSabharwal/as/restructure-adjoint
Browse files Browse the repository at this point in the history
feat: add adjoint for `ArrayInterface.restructure`
  • Loading branch information
ChrisRackauckas authored Nov 8, 2024
2 parents 8594f42 + c7216c3 commit 5c0b782
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 1 deletion.
7 changes: 6 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
CUDSS = "45b445bb-4962-46a0-9369-b4df9d0f772e"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Expand All @@ -24,6 +25,7 @@ ArrayInterfaceBlockBandedMatricesExt = "BlockBandedMatrices"
ArrayInterfaceCUDAExt = "CUDA"
ArrayInterfaceCUDSSExt = "CUDSS"
ArrayInterfaceChainRulesExt = "ChainRules"
ArrayInterfaceChainRulesCoreExt = "ChainRulesCore"
ArrayInterfaceGPUArraysCoreExt = "GPUArraysCore"
ArrayInterfaceReverseDiffExt = "ReverseDiff"
ArrayInterfaceSparseArraysExt = "SparseArrays"
Expand All @@ -37,6 +39,8 @@ BlockBandedMatrices = "0.13"
CUDA = "5"
CUDSS = "0.2, 0.3"
ChainRules = "1"
ChainRulesCore = "1"
ChainRulesTestUtils = "1"
GPUArraysCore = "0.1, 0.2"
LinearAlgebra = "1.10"
ReverseDiff = "1"
Expand All @@ -51,6 +55,7 @@ BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
Expand All @@ -66,4 +71,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"

[targets]
test = ["SafeTestsets", "Pkg", "Test", "Aqua", "Random", "SparseArrays", "SuiteSparse", "BandedMatrices", "BlockBandedMatrices", "GPUArraysCore", "StaticArrays", "Tracker", "ReverseDiff", "ChainRules", "FillArrays", "ComponentArrays"]
test = ["SafeTestsets", "Pkg", "Test", "Aqua", "Random", "SparseArrays", "SuiteSparse", "BandedMatrices", "BlockBandedMatrices", "GPUArraysCore", "StaticArrays", "Tracker", "ReverseDiff", "ChainRules", "FillArrays", "ComponentArrays", "ChainRulesTestUtils"]
22 changes: 22 additions & 0 deletions ext/ArrayInterfaceChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
module ArrayInterfaceChainRulesCoreExt

import ArrayInterface
import ChainRulesCore
import ChainRulesCore: unthunk, NoTangent, ZeroTangent, ProjectTo, @thunk

function ChainRulesCore.rrule(::typeof(ArrayInterface.restructure), target, src)
projectT = ProjectTo(target)
function restructure_pullback(dt)
dt = unthunk(dt)

= NoTangent()
= ZeroTangent()
= @thunk(projectT(ArrayInterface.restructure(src, dt)))

f̄, t̄, s̄
end

return ArrayInterface.restructure(target, src), restructure_pullback
end

end
7 changes: 7 additions & 0 deletions test/chainrules.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
using ArrayInterface, ChainRules, Test
using ComponentArrays, ChainRulesTestUtils, StaticArrays

x = ChainRules.OneElement(3.0, (3, 3), (1:4, 1:4))

@test !ArrayInterface.can_setindex(x)
@test !ArrayInterface.can_setindex(typeof(x))

arr = ComponentArray(a = 1.0, b = [2.0, 3.0], c = (; a = 4.0, b = 5.0), d = SVector{2}(6.0, 7.0))
b = zeros(length(arr))

ChainRulesTestUtils.test_rrule(ArrayInterface.restructure, arr, b)
ChainRulesTestUtils.test_rrule(ArrayInterface.restructure, b, arr)

0 comments on commit 5c0b782

Please sign in to comment.