diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index 5403c99fb..ed3d60fcf 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -12,6 +12,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" Diffractor = "9f5e2b26-1114-432f-b630-d3fe2085c51c" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" FastDifferentiation = "eb9bf01b-bf85-4b60-bf87-ee5de06c00be" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" @@ -29,7 +30,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] DifferentiationInterfaceChainRulesCoreExt = "ChainRulesCore" DifferentiationInterfaceDiffractorExt = "Diffractor" -DifferentiationInterfaceEnzymeExt = "Enzyme" +DifferentiationInterfaceEnzymeExt = ["EnzymeCore", "Enzyme"] DifferentiationInterfaceFastDifferentiationExt = "FastDifferentiation" DifferentiationInterfaceFiniteDiffExt = "FiniteDiff" DifferentiationInterfaceFiniteDifferencesExt = "FiniteDifferences" diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl index af437eef1..2ef5364ae 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl @@ -3,7 +3,7 @@ module DifferentiationInterfaceEnzymeExt using ADTypes: ADTypes, AutoEnzyme using Base: Fix1 import DifferentiationInterface as DI -using Enzyme: +using EnzymeCore: Active, Annotation, BatchDuplicated, @@ -27,7 +27,8 @@ using Enzyme: ReverseSplitWithPrimal, ReverseWithPrimal, Split, - WithPrimal, + WithPrimal +using Enzyme: autodiff, autodiff_thunk, create_shadows, diff --git a/DifferentiationInterface/test/Back/Enzyme/test.jl b/DifferentiationInterface/test/Back/Enzyme/test.jl index c18b46712..3d8c238c8 100644 --- a/DifferentiationInterface/test/Back/Enzyme/test.jl +++ b/DifferentiationInterface/test/Back/Enzyme/test.jl @@ -42,10 +42,12 @@ duplicated_backends = [ end; @testset "First order" begin + @info "Step 1" test_differentiation( backends, default_scenarios(); excluded=SECOND_ORDER, logging=LOGGING ) + @info "Step 2" test_differentiation( backends[1:3], default_scenarios(; include_normal=false, include_constantified=true); @@ -53,6 +55,7 @@ end; logging=LOGGING, ) + @info "Step 3" test_differentiation( duplicated_backends, default_scenarios(; include_normal=false, include_closurified=true); @@ -75,6 +78,7 @@ test_differentiation( =# @testset "Second order" begin + @info "Step 4" test_differentiation( [ AutoEnzyme(), @@ -87,12 +91,14 @@ test_differentiation( logging=LOGGING, ) + @info "Step 5" test_differentiation( AutoEnzyme(; mode=Enzyme.Forward); excluded=vcat(FIRST_ORDER, [:hessian, :hvp]), logging=LOGGING, ) + @info "Step 6" test_differentiation( AutoEnzyme(; mode=Enzyme.Reverse); excluded=vcat(FIRST_ORDER, [:second_derivative]), @@ -101,6 +107,7 @@ test_differentiation( end @testset "Sparse" begin + @info "Step 7" test_differentiation( MyAutoSparse.(AutoEnzyme(; function_annotation=Enzyme.Const)), remove_matrix_inputs(sparse_scenarios()); @@ -114,6 +121,7 @@ end DIT.operator_place(s) == :out && DIT.function_place(s) == :out end + @info "Step 8" test_differentiation( [AutoEnzyme(; mode=Enzyme.Forward), AutoEnzyme(; mode=Enzyme.Reverse)], filtered_static_scenarios;