From 47ce50ebaa480c68554ade7fd03f34343aa4fef9 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 6 Apr 2023 18:31:33 -0400 Subject: [PATCH 1/3] Testing using LuxTestUtils.jl --- ext/LuxComponentArraysExt.jl | 2 +- test/LocalPreferences.toml | 2 + test/Project.toml | 9 +- test/adapt.jl | 8 +- test/autodiff.jl | 32 --- test/contrib/freeze.jl | 35 ++- test/contrib/map.jl | 88 ++++--- test/contrib/share_parameters.jl | 74 +++--- test/contrib/training.jl | 62 ++--- test/ext/LuxFluxTransformExt.jl | 245 +++++++++--------- test/layers/basic.jl | 220 ++++++++-------- test/layers/containers.jl | 268 ++++++++++--------- test/layers/conv.jl | 432 ++++++++++++++++--------------- test/layers/dropout.jl | 169 ++++++------ test/layers/normalize.jl | 293 +++++++++++---------- test/layers/recurrent.jl | 175 +++++++------ test/nnlib.jl | 22 +- test/runtests.jl | 2 - test/test_utils.jl | 120 ++------- test/utils.jl | 45 +--- 20 files changed, 1115 insertions(+), 1188 deletions(-) create mode 100644 test/LocalPreferences.toml delete mode 100644 test/autodiff.jl diff --git a/ext/LuxComponentArraysExt.jl b/ext/LuxComponentArraysExt.jl index a8cdf0eb2b..288bd9607b 100644 --- a/ext/LuxComponentArraysExt.jl +++ b/ext/LuxComponentArraysExt.jl @@ -54,7 +54,7 @@ function CRC.rrule(::Type{ComponentArray}, nt::NamedTuple) "of shape $(size(res)) & type $(typeof(res))") return nothing end - CA_NT_pullback(Δ::ComponentArray) = (@show Δ; (CRC.NoTangent(), NamedTuple(Δ))) + CA_NT_pullback(Δ::ComponentArray) = CRC.NoTangent(), NamedTuple(Δ) return res, CA_NT_pullback end diff --git a/test/LocalPreferences.toml b/test/LocalPreferences.toml new file mode 100644 index 0000000000..bfc941cb4a --- /dev/null +++ b/test/LocalPreferences.toml @@ -0,0 +1,2 @@ +[LuxTestUtils] +target_modules = ["Lux", "LuxCore", "LuxLib"] diff --git a/test/Project.toml b/test/Project.toml index d99527b4c6..9f5442ff20 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,23 +1,22 @@ [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" -FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" -JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" +LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" +LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" +LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" +LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] diff --git a/test/adapt.jl b/test/adapt.jl index ffb4eb3f69..efa7b3f064 100644 --- a/test/adapt.jl +++ b/test/adapt.jl @@ -1,9 +1,9 @@ using Lux, Functors, Random, Test +import LuxCUDA +import LuxCUDA.CUDA -import CUDA - -if CUDA.functional() - using CUDA # exports CuArray, etc +if LuxCUDA.functional() + using LuxCUDA.CUDA # exports CuArray, etc @info "starting CUDA 
tests" else @info "CUDA not functional, testing via JLArrays" diff --git a/test/autodiff.jl b/test/autodiff.jl deleted file mode 100644 index 5a751c84a3..0000000000 --- a/test/autodiff.jl +++ /dev/null @@ -1,32 +0,0 @@ -using Lux, ComponentArrays, ReverseDiff, Random, Zygote, Test - -include("test_utils.jl") - -rng = Random.default_rng() -Random.seed!(rng, 0) - -@testset "Gradient Correctness: Dense Chain" begin - c = Chain(Dense(3, 4), Dense(4, 1)) - - x = randn(rng, Float32, 3, 2) - ps, st = Lux.setup(rng, c) - - ps = ps |> ComponentArray - - gs_r = ReverseDiff.gradient(ps -> sum(first(Lux.apply(c, x, ps, st))), ps) - gs_z = Zygote.gradient(ps -> sum(first(Lux.apply(c, x, ps, st))), ps)[1] - gs_fdm = FiniteDifferences.grad(FiniteDifferences.central_fdm(5, 1), - ps -> sum(first(Lux.apply(c, x, ps, st))), ps)[1] - - @test gs_r == gs_z - @test gs_r ≈ gs_fdm - @test gs_z ≈ gs_fdm -end - -@testset "Broadcasting identity custom rrule" begin - x = randn(rng, Float32, 3, 2) - gs_x_1 = Zygote.gradient(x -> sum(identity.(x)), x)[1] - gs_x_2 = Zygote.gradient(sum, x)[1] - - @test gs_x_1 == gs_x_2 -end diff --git a/test/contrib/freeze.jl b/test/contrib/freeze.jl index d4f6554bdc..153e541592 100644 --- a/test/contrib/freeze.jl +++ b/test/contrib/freeze.jl @@ -5,10 +5,10 @@ include("../test_utils.jl") rng = Random.default_rng() Random.seed!(rng, 0) -@testset "All Parameters Freezing" begin +@testset "$mode: All Parameters Freezing" for (mode, aType, device, ongpu) in MODES @testset "NamedTuple" begin d = Dense(5 => 5) - psd, std = Lux.setup(rng, d) + psd, std = Lux.setup(rng, d) .|> device fd, ps, st = Lux.freeze(d, psd, std, nothing) @test length(keys(ps)) == 0 @@ -16,33 +16,32 @@ Random.seed!(rng, 0) @test sort([keys(st)...]) == [:frozen_params, :states] @test sort([keys(st.frozen_params)...]) == [:bias, :weight] - x = randn(rng, Float32, 5, 1) + x = randn(rng, Float32, 5, 1) |> aType @test d(x, psd, std)[1] == fd(x, ps, st)[1] - run_JET_tests(fd, x, ps, st) - test_gradient_correctness_fdm(x -> sum(fd(x, ps, st)[1]), x; atol=1.0f-3, - rtol=1.0f-3) + @jet fd(x, ps, st) + __f = (x, ps) -> sum(first(fd(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu end @testset "ComponentArray" begin m = Chain(Lux.freeze(Dense(1 => 3, tanh)), Dense(3 => 1)) - ps, st = Lux.setup(rng, m) + ps, st = Lux.setup(rng, m) .|> device ps_c = ComponentVector(ps) - x = randn(rng, Float32, 1, 2) + x = randn(rng, Float32, 1, 2) |> aType @test m(x, ps, st)[1] == m(x, ps_c, st)[1] - run_JET_tests(m, x, ps_c, st) - # Tracker with empty ComponentArray is broken - # test_gradient_correctness_fdm((x, ps) -> sum(m(x, ps, st)[1]), x, ps_c; atol=1.0f-3, - # rtol=1.0f-3) + @jet m(x, ps_c, st) + # __f = (x, ps) -> sum(first(m(x, ps, st))) + # @eval @test_gradients $__f $x $ps_c atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_tracker=true end end -@testset "Partial Freezing" begin +@testset "$mode: Partial Freezing" for (mode, aType, device, ongpu) in MODES d = Dense(5 => 5) - psd, std = Lux.setup(rng, d) + psd, std = Lux.setup(rng, d) .|> device fd, ps, st = Lux.freeze(d, psd, std, (:weight,)) @test length(keys(ps)) == 1 @@ -51,11 +50,11 @@ end @test sort([keys(st.frozen_params)...]) == [:weight] @test sort([keys(ps)...]) == [:bias] - x = randn(rng, Float32, 5, 1) + x = randn(rng, Float32, 5, 1) |> aType @test d(x, psd, std)[1] == fd(x, ps, st)[1] - run_JET_tests(fd, x, ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(fd(x, ps, st)[1]), x, ps; atol=1.0f-3, - rtol=1.0f-3) + @jet fd(x, ps, st) + 
__f = (x, ps) -> sum(first(fd(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu end diff --git a/test/contrib/map.jl b/test/contrib/map.jl index b0500f5d67..707a45769d 100644 --- a/test/contrib/map.jl +++ b/test/contrib/map.jl @@ -1,13 +1,8 @@ using Lux, Random, Setfield, Test -c = Parallel(+; - chain=Chain(; dense_1=Dense(2 => 3), bn=BatchNorm(3), dense_2=Dense(3 => 5)), - dense_3=Dense(5 => 1)) +include("../test_utils.jl") -rng = Random.default_rng() -ps, st = Lux.setup(rng, c) - -function zero_dense_params(l, ps, st, name) +function zero_dense_params_1(l, ps, st, name) if l isa Dense && occursin("model.layers.chain", name) @set! ps.weight = zero.(ps.weight) @set! ps.bias = zero.(ps.bias) @@ -15,16 +10,7 @@ function zero_dense_params(l, ps, st, name) return l, ps, st end -c_, ps_, st_ = Lux.layer_map(zero_dense_params, c, ps, st) - -@test ps_.chain.dense_1.weight == zeros(3, 2) -@test ps_.chain.dense_1.bias == zeros(3, 1) -@test ps_.chain.dense_2.weight == zeros(5, 3) -@test ps_.chain.dense_2.bias == zeros(5, 1) -@test ps_.dense_3.weight != zeros(1, 5) -@test ps_.dense_3.bias == zeros(1, 1) - -function zero_dense_params(l, ps, st, name) +function zero_dense_params_2(l, ps, st, name) if l isa Dense && occursin("c.layers.chain", name) @set! ps.weight = zero.(ps.weight) @set! ps.bias = zero.(ps.bias) @@ -32,26 +18,7 @@ function zero_dense_params(l, ps, st, name) return l, ps, st end -c_, ps_, st_ = Lux.@layer_map zero_dense_params c ps st - -@test ps_.chain.dense_1.weight == zeros(3, 2) -@test ps_.chain.dense_1.bias == zeros(3, 1) -@test ps_.chain.dense_2.weight == zeros(5, 3) -@test ps_.chain.dense_2.bias == zeros(5, 1) -@test ps_.dense_3.weight != zeros(1, 5) -@test ps_.dense_3.bias == zeros(1, 1) - -# Custom Layers -- See https://github.com/avik-pal/Lux.jl/issues/187 -struct SimpleCustom{L1, L2} <: Lux.AbstractExplicitContainerLayer{(:dense, :conv)} - dense::L1 - conv::L2 -end - -l = SimpleCustom(Dense(3 => 2), Conv((3,), 3 => 2)) - -ps, st = Lux.setup(rng, l) - -function zero_dense_params(l, ps, st, name) +function zero_dense_params_3(l, ps, st, name) if l isa Dense @set! ps.weight = zero.(ps.weight) @set! 
ps.bias = zero.(ps.bias) @@ -59,9 +26,46 @@ function zero_dense_params(l, ps, st, name) return l, ps, st end -l_, ps_, st_ = Lux.@layer_map zero_dense_params l ps st +@testset "$mode" for (mode, aType, device, ongpu) in MODES + c = Parallel(+; + chain=Chain(; dense_1=Dense(2 => 3), bn=BatchNorm(3), + dense_2=Dense(3 => 5)), dense_3=Dense(5 => 1)) + + rng = Random.default_rng() + ps, st = Lux.setup(rng, c) .|> device -@test ps_.dense.weight == zeros(Float32, 2, 3) -@test ps_.dense.bias == zeros(Float32, 2, 1) -@test ps_.conv.weight != zeros(Float32, 3, 3, 2) -@test ps_.conv.bias == zeros(Float32, 1, 2, 1) + c_, ps_, st_ = Lux.layer_map(zero_dense_params_1, c, ps, st) + + @test all(iszero, ps_.chain.dense_1.weight) + @test all(iszero, ps_.chain.dense_1.bias) + @test all(iszero, ps_.chain.dense_2.weight) + @test all(iszero, ps_.chain.dense_2.bias) + @test !all(iszero, ps_.dense_3.weight) + @test all(iszero, ps_.dense_3.bias) + + c_, ps_, st_ = Lux.@layer_map zero_dense_params_2 c ps st + + @test all(iszero, ps_.chain.dense_1.weight) + @test all(iszero, ps_.chain.dense_1.bias) + @test all(iszero, ps_.chain.dense_2.weight) + @test all(iszero, ps_.chain.dense_2.bias) + @test !all(iszero, ps_.dense_3.weight) + @test all(iszero, ps_.dense_3.bias) + + # Custom Layers -- See https://github.com/avik-pal/Lux.jl/issues/187 + struct SimpleCustom{L1, L2} <: Lux.AbstractExplicitContainerLayer{(:dense, :conv)} + dense::L1 + conv::L2 + end + + l = SimpleCustom(Dense(3 => 2), Conv((3,), 3 => 2)) + + ps, st = Lux.setup(rng, l) .|> device + + l_, ps_, st_ = Lux.@layer_map zero_dense_params_3 l ps st + + @test all(iszero, ps_.dense.weight) + @test all(iszero, ps_.dense.bias) + @test !all(iszero, ps_.conv.weight) + @test all(iszero, ps_.conv.bias) +end diff --git a/test/contrib/share_parameters.jl b/test/contrib/share_parameters.jl index 333eae8e71..5abd3da401 100644 --- a/test/contrib/share_parameters.jl +++ b/test/contrib/share_parameters.jl @@ -1,55 +1,61 @@ -using Lux, Random, Test +using ComponentArrays, Lux, Random, Test include("../test_utils.jl") rng = Random.default_rng() Random.seed!(rng, 0) -model = Chain(; d1=Dense(2 => 4, tanh), d2=Chain(; l1=Dense(4 => 2), l2=Dense(2 => 4)), - d3=Dense(4 => 2)) +@testset "$mode" for (mode, aType, device, ongpu) in MODES + model = Chain(; d1=Dense(2 => 4, tanh), d2=Chain(; l1=Dense(4 => 2), l2=Dense(2 => 4)), + d3=Dense(4 => 2)) -ps, st = Lux.setup(rng, model) + ps, st = Lux.setup(rng, model) .|> device -sharing = (("d2.l2", "d1"), ("d3", "d2.l1")) + sharing = (("d2.l2", "d1"), ("d3", "d2.l1")) -ps_1 = Lux.share_parameters(ps, sharing) + ps_1 = Lux.share_parameters(ps, sharing) -@test ps_1.d2.l2.weight === ps_1.d1.weight -@test ps_1.d2.l2.bias === ps_1.d1.bias -@test ps_1.d3.weight === ps_1.d2.l1.weight -@test ps_1.d3.bias === ps_1.d2.l1.bias + @test ps_1.d2.l2.weight === ps_1.d1.weight + @test ps_1.d2.l2.bias === ps_1.d1.bias + @test ps_1.d3.weight === ps_1.d2.l1.weight + @test ps_1.d3.bias === ps_1.d2.l1.bias -ps_new_1 = (; weight=randn(rng, Float32, 4, 2), bias=randn(rng, Float32, 4, 1)) -ps_new_2 = (; weight=randn(rng, Float32, 2, 4), bias=randn(rng, Float32, 2, 1)) + ps_new_1 = (; weight=randn(rng, Float32, 4, 2), bias=randn(rng, Float32, 4, 1)) |> + device + ps_new_2 = (; weight=randn(rng, Float32, 2, 4), bias=randn(rng, Float32, 2, 1)) |> + device -ps_2 = Lux.share_parameters(ps, sharing, (ps_new_1, ps_new_2)) + ps_2 = Lux.share_parameters(ps, sharing, (ps_new_1, ps_new_2)) -@test ps_2.d2.l2.weight === ps_new_1.weight === ps_2.d1.weight -@test ps_2.d2.l2.bias 
=== ps_new_1.bias === ps_2.d1.bias -@test ps_2.d3.weight === ps_new_2.weight === ps_2.d2.l1.weight -@test ps_2.d3.bias === ps_new_2.bias === ps_2.d2.l1.bias + @test ps_2.d2.l2.weight === ps_new_1.weight === ps_2.d1.weight + @test ps_2.d2.l2.bias === ps_new_1.bias === ps_2.d1.bias + @test ps_2.d3.weight === ps_new_2.weight === ps_2.d2.l1.weight + @test ps_2.d3.bias === ps_new_2.bias === ps_2.d2.l1.bias -# Mix in ComponentArray -ps_new_ca_1 = ComponentArray(ps_new_1) + # Mix in ComponentArray + ps_new_ca_1 = ComponentArray(ps_new_1) -ps_3 = Lux.share_parameters(ps, sharing, (ps_new_ca_1, ps_new_2)) + ps_3 = Lux.share_parameters(ps, sharing, (ps_new_ca_1, ps_new_2)) -@test ps_3.d2.l2.weight === ps_new_ca_1.weight === ps_3.d1.weight -@test ps_3.d2.l2.bias === ps_new_ca_1.bias === ps_3.d1.bias -@test ps_3.d3.weight === ps_new_2.weight === ps_3.d2.l1.weight -@test ps_3.d3.bias === ps_new_2.bias === ps_3.d2.l1.bias + @test ps_3.d2.l2.weight === ps_new_ca_1.weight === ps_3.d1.weight + @test ps_3.d2.l2.bias === ps_new_ca_1.bias === ps_3.d1.bias + @test ps_3.d3.weight === ps_new_2.weight === ps_3.d2.l1.weight + @test ps_3.d3.bias === ps_new_2.bias === ps_3.d2.l1.bias -# Input Checks -non_disjoint_sharing = (("d2.l2", "d1"), ("d1", "d2.l1")) -@test_throws ArgumentError Lux.share_parameters(ps, non_disjoint_sharing) -@test_throws ArgumentError Lux.share_parameters(ps, sharing, (ps_new_1,)) + # Input Checks + non_disjoint_sharing = (("d2.l2", "d1"), ("d1", "d2.l1")) + @test_throws ArgumentError Lux.share_parameters(ps, non_disjoint_sharing) + @test_throws ArgumentError Lux.share_parameters(ps, sharing, (ps_new_1,)) -# Parameter Structure Mismatch -ps_new_1 = (; weight=randn(rng, Float32, 2, 4), bias=randn(rng, Float32, 4, 1)) -ps_new_2 = (; weight=randn(rng, Float32, 2, 4), bias=randn(rng, Float32, 2, 1)) + # Parameter Structure Mismatch + ps_new_1 = (; weight=randn(rng, Float32, 2, 4), bias=randn(rng, Float32, 4, 1)) |> + device + ps_new_2 = (; weight=randn(rng, Float32, 2, 4), bias=randn(rng, Float32, 2, 1)) |> + device -@test_throws ArgumentError Lux.share_parameters(ps, sharing, (ps_new_1, ps_new_2)) + @test_throws ArgumentError Lux.share_parameters(ps, sharing, (ps_new_1, ps_new_2)) -ps_new_ca_1 = ComponentArray(ps_new_1) + ps_new_ca_1 = ComponentArray(ps_new_1) -@test_throws ArgumentError Lux.share_parameters(ps, sharing, (ps_new_ca_1, ps_new_2)) + @test_throws ArgumentError Lux.share_parameters(ps, sharing, (ps_new_ca_1, ps_new_2)) +end diff --git a/test/contrib/training.jl b/test/contrib/training.jl index 2c783182b0..feb1a621f1 100644 --- a/test/contrib/training.jl +++ b/test/contrib/training.jl @@ -2,50 +2,49 @@ using Lux, Optimisers, Random, Test include("../test_utils.jl") -function _get_TrainState() - rng = MersenneTwister(0) - - model = Lux.Dense(3, 2) - opt = Optimisers.Adam(0.01f0) - - tstate = Lux.Training.TrainState(Lux.replicate(rng), model, opt) - - x = randn(Lux.replicate(rng), Float32, (3, 1)) - - return rng, tstate, model, opt, x -end - function _loss_function(model, ps, st, data) y, st = model(data, ps, st) return sum(y), st, () end -function test_TrainState_constructor() - rng, tstate, model, opt, _ = _get_TrainState() +@testset "$mode: TrainState" for (mode, aType, device, ongpu) in MODES + rng = MersenneTwister(0) + + model = Dense(3, 2) + opt = Adam(0.01f0) + + tstate = Lux.Training.TrainState(Lux.replicate(rng), model, opt; + transform_variables=device) - ps, st = Lux.setup(Lux.replicate(rng), model) + x = randn(Lux.replicate(rng), Float32, (3, 1)) |> aType + + ps, st = 
Lux.setup(Lux.replicate(rng), model) .|> device opt_st = Optimisers.setup(opt, tstate.parameters) - @test tstate.model == model - @test tstate.parameters == ps - @test tstate.states == st - @test isapprox(tstate.optimizer_state, opt_st) + @test check_approx(tstate.model, model) + @test check_approx(tstate.parameters, ps) + @test check_approx(tstate.states, st) + @test check_approx(tstate.optimizer_state, opt_st) @test tstate.step == 0 - - return nothing end -function test_abstract_vjp_interface() - _, tstate, _, _, x = _get_TrainState() +@testset "$mode: AbstractVJP" for (mode, aType, device, ongpu) in MODES + rng = MersenneTwister(0) + + model = Dense(3, 2) + opt = Adam(0.01f0) + + tstate = Lux.Training.TrainState(Lux.replicate(rng), model, opt; + transform_variables=device) - @testset "NotImplemented" begin for vjp_rule in (Lux.Training.EnzymeVJP(), - Lux.Training.YotaVJP()) + x = randn(Lux.replicate(rng), Float32, (3, 1)) |> aType + + @testset "NotImplemented $(string(vjp_rule))" for vjp_rule in (Lux.Training.EnzymeVJP(), + Lux.Training.YotaVJP()) @test_throws ArgumentError Lux.Training.compute_gradients(vjp_rule, _loss_function, x, tstate) - end end + end - # Gradient Correctness should be tested in `test/autodiff.jl` and other parts of the - # testing codebase. Here we only test that the API works. for vjp_rule in (Lux.Training.ZygoteVJP(), Lux.Training.TrackerVJP()) grads, _, _, _ = @test_nowarn Lux.Training.compute_gradients(vjp_rule, _loss_function, x, @@ -54,9 +53,4 @@ function test_abstract_vjp_interface() @test tstate_.step == 1 @test tstate != tstate_ end - - return nothing end - -@testset "TrainState" begin test_TrainState_constructor() end -@testset "AbstractVJP" begin test_abstract_vjp_interface() end diff --git a/test/ext/LuxFluxTransformExt.jl b/test/ext/LuxFluxTransformExt.jl index 2b436e7833..9e8387a184 100644 --- a/test/ext/LuxFluxTransformExt.jl +++ b/test/ext/LuxFluxTransformExt.jl @@ -1,79 +1,86 @@ import Flux using Lux, Random, Test -@testset "LuxFluxTransformExt" begin +fdevice(::typeof(cpu)) = Flux.cpu +fdevice(::typeof(gpu)) = Flux.gpu + +include("../test_utils.jl") + +@testset "$mode: LuxFluxTransformExt" for (mode, aType, device, ongpu) in MODES @testset "Containers" begin @testset "Chain" begin - model = Flux.Chain(Flux.Dense(2 => 5), Flux.Dense(5 => 1)) - x = rand(Float32, 2, 1) + model = Flux.Chain(Flux.Dense(2 => 5), Flux.Dense(5 => 1)) |> fdevice(device) + x = rand(Float32, 2, 1) |> aType model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(Random.default_rng(), model_lux) + ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device @test model(x) ≈ model_lux(x, ps, st)[1] model_lux = transform(model) - ps, st = Lux.setup(Random.default_rng(), model_lux) + ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device @test size(model_lux(x, ps, st)[1]) == (1, 1) end @testset "Maxout" begin - model = Flux.Maxout(() -> Flux.Dense(2 => 5), 4) - x = rand(Float32, 2, 1) + model = Flux.Maxout(() -> Flux.Dense(2 => 5), 4) |> fdevice(device) + x = rand(Float32, 2, 1) |> aType model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(Random.default_rng(), model_lux) + ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device @test model(x) ≈ model_lux(x, ps, st)[1] model_lux = transform(model) - ps, st = Lux.setup(Random.default_rng(), model_lux) + ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device @test size(model_lux(x, ps, st)[1]) == (5, 1) end @testset "Skip Connection" begin - model = 
Flux.SkipConnection(Flux.Dense(2 => 2), +) - x = rand(Float32, 2, 1) + model = Flux.SkipConnection(Flux.Dense(2 => 2), +) |> fdevice(device) + x = rand(Float32, 2, 1) |> aType model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(Random.default_rng(), model_lux) + ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device @test model(x) ≈ model_lux(x, ps, st)[1] model_lux = transform(model) - ps, st = Lux.setup(Random.default_rng(), model_lux) + ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device @test size(model_lux(x, ps, st)[1]) == (2, 1) end @testset "Parallel" begin - model = Flux.Parallel(+, Flux.Dense(2 => 2), Flux.Dense(2 => 2)) - x = rand(Float32, 2, 1) + model = Flux.Parallel(+, Flux.Dense(2 => 2), Flux.Dense(2 => 2)) |> + fdevice(device) + x = rand(Float32, 2, 1) |> aType model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(Random.default_rng(), model_lux) + ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device @test model(x) ≈ model_lux(x, ps, st)[1] model_lux = transform(model) - ps, st = Lux.setup(Random.default_rng(), model_lux) + ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device @test size(model_lux(x, ps, st)[1]) == (2, 1) end @testset "Pairwise Fusion" begin - model = Flux.PairwiseFusion(+, Flux.Dense(2 => 2), Flux.Dense(2 => 2)) - x = (rand(Float32, 2, 1), rand(Float32, 2, 1)) + model = Flux.PairwiseFusion(+, Flux.Dense(2 => 2), Flux.Dense(2 => 2)) |> + fdevice(device) + x = (rand(Float32, 2, 1), rand(Float32, 2, 1)) .|> aType model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(Random.default_rng(), model_lux) + ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device @test all(model(x) .≈ model_lux(x, ps, st)[1]) model_lux = transform(model) - ps, st = Lux.setup(Random.default_rng(), model_lux) + ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device @test all(size.(model_lux(x, ps, st)[1]) .== ((2, 1),)) end @@ -81,65 +88,68 @@ using Lux, Random, Test @testset "Linear" begin @testset "Dense" begin for model in [ - Flux.Dense(2 => 4), - Flux.Dense(2 => 4; bias=false), + Flux.Dense(2 => 4) |> fdevice(device), + Flux.Dense(2 => 4; bias=false) |> fdevice(device), ] - x = randn(Float32, 2, 4) + x = randn(Float32, 2, 4) |> aType model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(Random.default_rng(), model_lux) + ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device @test model(x) ≈ model_lux(x, ps, st)[1] model_lux = transform(model) - ps, st = Lux.setup(Random.default_rng(), model_lux) + ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device @test size(model_lux(x, ps, st)[1]) == size(model(x)) end end - @testset "Scale" begin for model in [Flux.Scale(2), Flux.Scale(2; bias=false)] - x = randn(Float32, 2, 4) + @testset "Scale" begin for model in [ + Flux.Scale(2) |> fdevice(device), + Flux.Scale(2; bias=false) |> fdevice(device), + ] + x = randn(Float32, 2, 4) |> aType model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(Random.default_rng(), model_lux) + ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device @test model(x) ≈ model_lux(x, ps, st)[1] model_lux = transform(model) - ps, st = Lux.setup(Random.default_rng(), model_lux) + ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device @test size(model_lux(x, ps, st)[1]) == size(model(x)) end end @testset "Bilinear" begin for model in [ - Flux.Bilinear((2, 3) => 5), - Flux.Bilinear((2, 3) => 5; bias=false), + Flux.Bilinear((2, 3) => 5) |> 
fdevice(device), + Flux.Bilinear((2, 3) => 5; bias=false) |> fdevice(device), ] - x = randn(Float32, 2, 4) - y = randn(Float32, 3, 4) + x = randn(Float32, 2, 4) |> aType + y = randn(Float32, 3, 4) |> aType model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(Random.default_rng(), model_lux) + ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device @test model(x, y) ≈ model_lux((x, y), ps, st)[1] model_lux = transform(model) - ps, st = Lux.setup(Random.default_rng(), model_lux) + ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device @test size(model_lux((x, y), ps, st)[1]) == size(model(x, y)) end end @testset "Embedding" begin - model = Flux.Embedding(16 => 4) - x = rand(1:16, 2, 4) + model = Flux.Embedding(16 => 4) |> fdevice(device) + x = rand(1:16, 2, 4) |> aType model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(Random.default_rng(), model_lux) + ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device @test model(x) ≈ model_lux(x, ps, st)[1] model_lux = transform(model) - ps, st = Lux.setup(Random.default_rng(), model_lux) + ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device @test size(model_lux(x, ps, st)[1]) == (4, 2, 4) end @@ -147,55 +157,56 @@ using Lux, Random, Test @testset "Convolutions" begin @testset "Conv" begin - model = Flux.Conv((3, 3), 1 => 2) - x = rand(Float32, 6, 6, 1, 4) + model = Flux.Conv((3, 3), 1 => 2) |> fdevice(device) + x = rand(Float32, 6, 6, 1, 4) |> aType model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(Random.default_rng(), model_lux) + ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device @test model(x) ≈ model_lux(x, ps, st)[1] - model = Flux.Conv((3, 3), 1 => 2; pad=Flux.SamePad()) - x = rand(Float32, 6, 6, 1, 4) + model = Flux.Conv((3, 3), 1 => 2; pad=Flux.SamePad()) |> fdevice(device) + x = rand(Float32, 6, 6, 1, 4) |> aType model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(Random.default_rng(), model_lux) + ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device @test model(x) ≈ model_lux(x, ps, st)[1] end @testset "CrossCor" begin - model = Flux.CrossCor((3, 3), 1 => 2) - x = rand(Float32, 6, 6, 1, 4) + model = Flux.CrossCor((3, 3), 1 => 2) |> fdevice(device) + x = rand(Float32, 6, 6, 1, 4) |> aType model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(Random.default_rng(), model_lux) + ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device @test model(x) ≈ model_lux(x, ps, st)[1] - model = Flux.CrossCor((3, 3), 1 => 2; pad=Flux.SamePad()) - x = rand(Float32, 6, 6, 1, 4) + model = Flux.CrossCor((3, 3), 1 => 2; pad=Flux.SamePad()) |> fdevice(device) + x = rand(Float32, 6, 6, 1, 4) |> aType model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(Random.default_rng(), model_lux) + ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device @test model(x) ≈ model_lux(x, ps, st)[1] end @testset "ConvTranspose" begin - model = Flux.ConvTranspose((3, 3), 1 => 2) - x = rand(Float32, 6, 6, 1, 4) + model = Flux.ConvTranspose((3, 3), 1 => 2) |> fdevice(device) + x = rand(Float32, 6, 6, 1, 4) |> aType model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(Random.default_rng(), model_lux) + ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device @test model(x) ≈ model_lux(x, ps, st)[1] - model = Flux.ConvTranspose((3, 3), 1 => 2; pad=Flux.SamePad()) - x = rand(Float32, 6, 6, 1, 4) + model = Flux.ConvTranspose((3, 3), 1 => 2; pad=Flux.SamePad()) |> + fdevice(device) + 
x = rand(Float32, 6, 6, 1, 4) |> aType model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(Random.default_rng(), model_lux) + ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device @test model(x) ≈ model_lux(x, ps, st)[1] end @@ -203,61 +214,61 @@ using Lux, Random, Test @testset "Pooling" begin @testset "AdaptiveMaxPooling" begin - model = Flux.AdaptiveMaxPool((2, 2)) - x = rand(Float32, 6, 6, 1, 4) + model = Flux.AdaptiveMaxPool((2, 2)) |> fdevice(device) + x = rand(Float32, 6, 6, 1, 4) |> aType model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(Random.default_rng(), model_lux) + ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device @test model(x) ≈ model_lux(x, ps, st)[1] end @testset "AdaptiveMeanPooling" begin - model = Flux.AdaptiveMeanPool((2, 2)) - x = rand(Float32, 6, 6, 1, 4) + model = Flux.AdaptiveMeanPool((2, 2)) |> fdevice(device) + x = rand(Float32, 6, 6, 1, 4) |> aType model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(Random.default_rng(), model_lux) + ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device @test model(x) ≈ model_lux(x, ps, st)[1] end @testset "MaxPooling" begin - model = Flux.MaxPool((2, 2)) - x = rand(Float32, 6, 6, 1, 4) + model = Flux.MaxPool((2, 2)) |> fdevice(device) + x = rand(Float32, 6, 6, 1, 4) |> aType model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(Random.default_rng(), model_lux) + ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device @test model(x) ≈ model_lux(x, ps, st)[1] end @testset "MeanPooling" begin - model = Flux.MeanPool((2, 2)) - x = rand(Float32, 6, 6, 1, 4) + model = Flux.MeanPool((2, 2)) |> fdevice(device) + x = rand(Float32, 6, 6, 1, 4) |> aType model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(Random.default_rng(), model_lux) + ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device @test model(x) ≈ model_lux(x, ps, st)[1] end @testset "GlobalMaxPooling" begin - model = Flux.GlobalMaxPool() - x = rand(Float32, 6, 6, 1, 4) + model = Flux.GlobalMaxPool() |> fdevice(device) + x = rand(Float32, 6, 6, 1, 4) |> aType model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(Random.default_rng(), model_lux) + ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device @test model(x) ≈ model_lux(x, ps, st)[1] end @testset "GlobalMeanPooling" begin - model = Flux.GlobalMeanPool() - x = rand(Float32, 6, 6, 1, 4) + model = Flux.GlobalMeanPool() |> fdevice(device) + x = rand(Float32, 6, 6, 1, 4) |> aType model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(Random.default_rng(), model_lux) + ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device @test model(x) ≈ model_lux(x, ps, st)[1] end @@ -265,22 +276,22 @@ using Lux, Random, Test @testset "Upsampling" begin @testset "Upsample" begin - model = Flux.Upsample(5) - x = rand(Float32, 2, 2, 2, 1) + model = Flux.Upsample(5) |> fdevice(device) + x = rand(Float32, 2, 2, 2, 1) |> aType model_lux = transform(model) - ps, st = Lux.setup(Random.default_rng(), model_lux) + ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device @test size(model_lux(x, ps, st)[1]) == (10, 10, 2, 1) @test model(x) ≈ model_lux(x, ps, st)[1] end @testset "PixelShuffle" begin - model = Flux.PixelShuffle(2) - x = randn(Float32, 2, 2, 4, 1) + model = Flux.PixelShuffle(2) |> fdevice(device) + x = randn(Float32, 2, 2, 4, 1) |> aType model_lux = transform(model) - ps, st = Lux.setup(Random.default_rng(), model_lux) + ps, st = 
Lux.setup(Random.default_rng(), model_lux) .|> device @test size(model_lux(x, ps, st)[1]) == (4, 4, 1, 1) @test model(x) ≈ model_lux(x, ps, st)[1] @@ -291,31 +302,31 @@ using Lux, Random, Test # @test_throws Lux.FluxModelConversionError transform(Flux.RNN(2 => 2)) @testset "RNNCell" begin - model = Flux.RNNCell(2 => 3) - x = rand(Float32, 2, 4) + model = Flux.RNNCell(2 => 3) |> fdevice(device) + x = rand(Float32, 2, 4) |> aType model_lux = transform(model) - ps, st = Lux.setup(Random.default_rng(), model_lux) + ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device @test size(model_lux(x, ps, st)[1][1]) == (3, 4) end @testset "LSTMCell" begin - model = Flux.LSTMCell(2 => 3) - x = rand(Float32, 2, 4) + model = Flux.LSTMCell(2 => 3) |> fdevice(device) + x = rand(Float32, 2, 4) |> aType model_lux = transform(model) - ps, st = Lux.setup(Random.default_rng(), model_lux) + ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device @test size(model_lux(x, ps, st)[1][1]) == (3, 4) end @testset "GRUCell" begin - model = Flux.GRUCell(2 => 3) - x = rand(Float32, 2, 4) + model = Flux.GRUCell(2 => 3) |> fdevice(device) + x = rand(Float32, 2, 4) |> aType model_lux = transform(model) - ps, st = Lux.setup(Random.default_rng(), model_lux) + ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device @test size(model_lux(x, ps, st)[1][1]) == (3, 4) end @@ -323,60 +334,60 @@ using Lux, Random, Test @testset "Normalize" begin @testset "BatchNorm" begin - model = Flux.BatchNorm(2) - x = randn(Float32, 2, 4) + model = Flux.BatchNorm(2) |> fdevice(device) + x = randn(Float32, 2, 4) |> aType model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(Random.default_rng(), model_lux) + ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device st = Lux.testmode(st) @test model(x) ≈ model_lux(x, ps, st)[1] - x = randn(Float32, 2, 2, 2, 1) + x = randn(Float32, 2, 2, 2, 1) |> aType @test model(x) ≈ model_lux(x, ps, st)[1] model_lux = transform(model; preserve_ps_st=true, force_preserve=true) - ps, st = Lux.setup(Random.default_rng(), model_lux) + ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device st = Lux.testmode(st) @test model(x) ≈ model_lux(x, ps, st)[1] end @testset "GroupNorm" begin - model = Flux.GroupNorm(4, 2) - x = randn(Float32, 2, 2, 4, 1) + model = Flux.GroupNorm(4, 2) |> fdevice(device) + x = randn(Float32, 2, 2, 4, 1) |> aType model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(Random.default_rng(), model_lux) + ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device st = Lux.testmode(st) @test model(x) ≈ model_lux(x, ps, st)[1] model_lux = transform(model; preserve_ps_st=true, force_preserve=true) - ps, st = Lux.setup(Random.default_rng(), model_lux) + ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device st = Lux.testmode(st) @test model(x) ≈ model_lux(x, ps, st)[1] end @testset "LayerNorm" begin - model = Flux.LayerNorm(4) - x = randn(Float32, 4, 4, 4, 1) + model = Flux.LayerNorm(4) |> fdevice(device) + x = randn(Float32, 4, 4, 4, 1) |> aType model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(Random.default_rng(), model_lux) + ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device st = Lux.testmode(st) @test model(x) ≈ model_lux(x, ps, st)[1] end @testset "InstanceNorm" begin - model = Flux.InstanceNorm(4) - x = randn(Float32, 4, 4, 4, 1) + model = Flux.InstanceNorm(4) |> fdevice(device) + x = randn(Float32, 4, 4, 4, 1) |> aType model_lux = transform(model; preserve_ps_st=true) - ps, st = 
Lux.setup(Random.default_rng(), model_lux) + ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device @test model(x) ≈ model_lux(x, ps, st)[1] end @@ -386,13 +397,13 @@ using Lux, Random, Test @testset "Dropout" begin model = transform(Flux.Dropout(0.5f0)) - x = randn(Float32, 2, 4) - ps, st = Lux.setup(Random.default_rng(), model) + x = randn(Float32, 2, 4) |> aType + ps, st = Lux.setup(Random.default_rng(), model) .|> device @test size(model(x, ps, st)[1]) == size(x) - x = randn(Float32, 2, 3, 4) - ps, st = Lux.setup(Random.default_rng(), model) + x = randn(Float32, 2, 3, 4) |> aType + ps, st = Lux.setup(Random.default_rng(), model) .|> device @test size(model(x, ps, st)[1]) == size(x) end @@ -400,13 +411,13 @@ using Lux, Random, Test @testset "AlphaDropout" begin model = transform(Flux.AlphaDropout(0.5)) - x = randn(Float32, 2, 4) - ps, st = Lux.setup(Random.default_rng(), model) + x = randn(Float32, 2, 4) |> aType + ps, st = Lux.setup(Random.default_rng(), model) .|> device @test size(model(x, ps, st)[1]) == size(x) - x = randn(Float32, 2, 4, 3) - ps, st = Lux.setup(Random.default_rng(), model) + x = randn(Float32, 2, 4, 3) |> aType + ps, st = Lux.setup(Random.default_rng(), model) .|> device @test size(model(x, ps, st)[1]) == size(x) end @@ -422,11 +433,11 @@ using Lux, Random, Test (c::CustomFluxLayer)(x) = c.weight .* x .+ c.bias - c = CustomFluxLayer(randn(10), randn(10)) - x = randn(10) + c = CustomFluxLayer(randn(10), randn(10)) |> fdevice(device) + x = randn(10) |> aType c_lux = transform(c) - ps, st = Lux.setup(Random.default_rng(), c_lux) + ps, st = Lux.setup(Random.default_rng(), c_lux) .|> device @test c(x) ≈ c_lux(x, ps, st)[1] end diff --git a/test/layers/basic.jl b/test/layers/basic.jl index c6341b9872..2ecc32dd5b 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -5,93 +5,99 @@ include("../test_utils.jl") rng = Random.default_rng() Random.seed!(rng, 0) -@testset "Miscellaneous Layers" begin +@testset "$mode: Miscellaneous Layers" for (mode, aType, device, ongpu) in MODES @testset "Reshape Layer" begin layer = ReshapeLayer((2, 3)) display(layer) - ps, st = Lux.setup(rng, layer) - x = randn(rng, 6, 3) + ps, st = Lux.setup(rng, layer) .|> device + x = randn(rng, 6, 3) |> aType @test size(layer(x, ps, st)[1]) == (2, 3, 3) - run_JET_tests(layer, x, ps, st) - test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol=1.0f-3, - rtol=1.0f-3) + + @jet layer(x, ps, st) + __f = x -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x gpu_testing=$ongpu atol=1.0f-3 rtol=1.0f-3 end @testset "Flatten Layer" begin layer = FlattenLayer() display(layer) - ps, st = Lux.setup(rng, layer) - x = randn(rng, 6, 3, 2) + ps, st = Lux.setup(rng, layer) .|> device + x = randn(rng, 6, 3, 2) |> aType @test size(layer(x, ps, st)[1]) == (18, 2) - run_JET_tests(layer, x, ps, st) - test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol=1.0f-3, - rtol=1.0f-3) + + @jet layer(x, ps, st) + __f = x -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x gpu_testing=$ongpu atol=1.0f-3 rtol=1.0f-3 end @testset "NoOpLayer" begin layer = NoOpLayer() display(layer) - ps, st = Lux.setup(rng, layer) + ps, st = Lux.setup(rng, layer) .|> device x = (x=2, b=5) # Something totally arbitrary @test layer(x, ps, st)[1] == x - run_JET_tests(layer, x, ps, st) - x = randn(rng, 6, 3) - test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol=1.0f-3, - rtol=1.0f-3) + @jet layer(x, ps, st) + + x = randn(rng, 6, 3) |> aType + __f = x -> sum(first(layer(x, 
ps, st))) + @eval @test_gradients $__f $x gpu_testing=$ongpu atol=1.0f-3 rtol=1.0f-3 end @testset "SelectDim Layer" begin layer = SelectDim(3, 1) display(layer) - ps, st = Lux.setup(rng, layer) - x = randn(rng, 6, 4, 3, 2) + ps, st = Lux.setup(rng, layer) .|> device + x = randn(rng, 6, 4, 3, 2) |> aType @test size(layer(x, ps, st)[1]) == (6, 4, 2) - run_JET_tests(layer, x, ps, st) - test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol=1.0f-3, - rtol=1.0f-3) + + @jet layer(x, ps, st) + __f = x -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x gpu_testing=$ongpu atol=1.0f-3 rtol=1.0f-3 end @testset "WrappedFunction" begin layer = WrappedFunction(x -> x .* x) display(layer) - ps, st = Lux.setup(rng, layer) - x = randn(rng, 6, 4, 3, 2) + ps, st = Lux.setup(rng, layer) .|> device + x = randn(rng, 6, 4, 3, 2) |> aType @test layer(x, ps, st)[1] == x .* x - run_JET_tests(layer, x, ps, st) - test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol=1.0f-3, - rtol=1.0f-3) + + @jet layer(x, ps, st) + __f = x -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x gpu_testing=$ongpu atol=1.0f-3 rtol=1.0f-3 end @testset "ActivationFunction" begin layer = ActivationFunction(tanh) display(layer) - ps, st = Lux.setup(rng, layer) - x = randn(rng, 6, 4, 3, 2) + ps, st = Lux.setup(rng, layer) .|> device + x = randn(rng, 6, 4, 3, 2) |> aType @test layer(x, ps, st)[1] == tanh.(x) - run_JET_tests(layer, x, ps, st) - test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol=1.0f-3, - rtol=1.0f-3) + + @jet layer(x, ps, st) + __f = x -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x gpu_testing=$ongpu atol=1.0f-3 rtol=1.0f-3 end end -@testset "Dense" begin +@testset "$mode: Dense" for (mode, aType, device, ongpu) in MODES @testset "constructors" begin layer = Dense(10, 100) - ps, st = Lux.setup(rng, layer) + ps, st = Lux.setup(rng, layer) .|> device @test size(ps.weight) == (100, 10) @test size(ps.bias) == (100, 1) @test layer.activation == identity layer = Dense(10, 100, relu; use_bias=false) - ps, st = Lux.setup(rng, layer) + ps, st = Lux.setup(rng, layer) .|> device @test !haskey(ps, :bias) @test layer.activation == relu @@ -115,28 +121,30 @@ end @testset "zeros" begin @test begin layer = Dense(10, 1, identity; init_weight=ones) - first(Lux.apply(layer, ones(10, 1), Lux.setup(rng, layer)...)) - end == 10 * ones(1, 1) + first(Lux.apply(layer, ones(10, 1) |> aType, device.(Lux.setup(rng, layer))...)) + end == 10 * aType(ones(1, 1)) @test begin layer = Dense(10, 1, identity; init_weight=ones) - first(Lux.apply(layer, ones(10, 2), Lux.setup(rng, layer)...)) - end == 10 * ones(1, 2) + first(Lux.apply(layer, ones(10, 2) |> aType, device.(Lux.setup(rng, layer))...)) + end == 10 * aType(ones(1, 2)) @test begin layer = Dense(10, 2, identity; init_weight=ones) - first(Lux.apply(layer, ones(10, 1), Lux.setup(rng, layer)...)) - end == 10 * ones(2, 1) + first(Lux.apply(layer, ones(10, 1) |> aType, device.(Lux.setup(rng, layer))...)) + end == 10 * aType(ones(2, 1)) @test begin layer = Dense(10, 2, identity; init_weight=ones) - first(Lux.apply(layer, [ones(10, 1) 2 * ones(10, 1)], Lux.setup(rng, layer)...)) - end == [10 20; 10 20] + first(Lux.apply(layer, aType([ones(10, 1) 2 * ones(10, 1)]), + device.(Lux.setup(rng, layer))...)) + end == aType([10 20; 10 20]) @test begin - layer = Dense(10, 2, identity; init_weight=ones, bias=false) - first(Lux.apply(layer, [ones(10, 1) 2 * ones(10, 1)], Lux.setup(rng, layer)...)) - end == [10 20; 10 20] + layer = 
Dense(10, 2, identity; init_weight=ones, use_bias=false) + first(Lux.apply(layer, aType([ones(10, 1) 2 * ones(10, 1)]), + device.(Lux.setup(rng, layer))...)) + end == aType([10 20; 10 20]) end # Deprecated Functionality (Remove in v0.5) @@ -147,17 +155,17 @@ end end end -@testset "Scale" begin +@testset "$mode: Scale" for (mode, aType, device, ongpu) in MODES @testset "constructors" begin layer = Scale(10, 100) - ps, st = Lux.setup(rng, layer) + ps, st = Lux.setup(rng, layer) .|> device @test size(ps.weight) == (10, 100) @test size(ps.bias) == (10, 100) @test layer.activation == identity layer = Scale(10, 100, relu; use_bias=false) - ps, st = Lux.setup(rng, layer) + ps, st = Lux.setup(rng, layer) .|> device @test !haskey(ps, :bias) @test layer.activation == relu @@ -172,32 +180,32 @@ end @testset "dimensions" begin layer = Scale(10, 5) - ps, st = Lux.setup(rng, layer) + ps, st = Lux.setup(rng, layer) .|> device - @test size(first(Lux.apply(layer, randn(10), ps, st))) == (10, 5) - @test size(first(Lux.apply(layer, randn(10, 5, 2), ps, st))) == (10, 5, 2) + @test size(first(Lux.apply(layer, randn(10) |> aType, ps, st))) == (10, 5) + @test size(first(Lux.apply(layer, randn(10, 5, 2) |> aType, ps, st))) == (10, 5, 2) end @testset "zeros" begin @test begin layer = Scale(10, 1, identity; init_weight=ones) - first(Lux.apply(layer, ones(10, 1), Lux.setup(rng, layer)...)) - end == ones(10, 1) + first(Lux.apply(layer, ones(10, 1) |> aType, device.(Lux.setup(rng, layer))...)) + end == aType(ones(10, 1)) @test begin layer = Scale(10, 1, identity; init_weight=ones) - first(Lux.apply(layer, ones(10, 2), Lux.setup(rng, layer)...)) - end == ones(10, 2) + first(Lux.apply(layer, ones(10, 2) |> aType, device.(Lux.setup(rng, layer))...)) + end == aType(ones(10, 2)) @test begin layer = Scale(2, identity; init_weight=ones, init_bias=ones) - first(Lux.apply(layer, [1 2; 3 4], Lux.setup(rng, layer)...)) - end == [2.0 3.0; 4.0 5.0] + first(Lux.apply(layer, [1 2; 3 4] |> aType, device.(Lux.setup(rng, layer))...)) + end == aType([2.0 3.0; 4.0 5.0]) @test begin layer = Scale(2, tanh; bias=false, init_weight=zeros) - first(Lux.apply(layer, [1 2; 3 4], Lux.setup(rng, layer)...)) - end == zeros(2, 2) + first(Lux.apply(layer, [1 2; 3 4] |> aType, device.(Lux.setup(rng, layer))...)) + end == aType(zeros(2, 2)) end # Deprecated Functionality (Remove in v0.5) @@ -208,7 +216,7 @@ end end end -@testset "Bilinear" begin +@testset "$mode: Bilinear" for (mode, aType, device, ongpu) in MODES @testset "SkipConnection recombinator" begin d = Dense(2 => 2) display(d) @@ -216,14 +224,14 @@ end display(b) layer = SkipConnection(d, b) display(layer) - ps, st = Lux.setup(rng, layer) - x = randn(rng, Float32, 2, 1) + ps, st = Lux.setup(rng, layer) .|> device + x = randn(rng, Float32, 2, 1) |> aType @test size(layer(x, ps, st)[1]) == (3, 1) - @inferred layer(x, ps, st) - run_JET_tests(layer, x, ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; - atol=1.0f-3, rtol=1.0f-3) + + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu d = Dense(2 => 2) display(d) @@ -231,61 +239,61 @@ end display(b) layer = SkipConnection(d, b) display(layer) - ps, st = Lux.setup(rng, layer) - x = randn(rng, Float32, 2, 1) + ps, st = Lux.setup(rng, layer) .|> device + x = randn(rng, Float32, 2, 1) |> aType @test size(layer(x, ps, st)[1]) == (3, 1) - @inferred layer(x, ps, st) - run_JET_tests(layer, x, ps, st) - 
test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; - atol=1.0f-3, rtol=1.0f-3) + + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu end @testset "Two-streams zero sum" begin - x = zeros(Float32, 2, 1) - y = zeros(Float32, 1, 1) + x = zeros(Float32, 2, 1) |> aType + y = zeros(Float32, 1, 1) |> aType layer = Bilinear((2, 1) => 3) display(layer) - ps, st = Lux.setup(rng, layer) + ps, st = Lux.setup(rng, layer) .|> device @test size(layer((x, y), ps, st)[1]) == (3, 1) @test sum(abs2, layer((x, y), ps, st)[1]) == 0.0f0 - @inferred layer((x, y), ps, st) - run_JET_tests(layer, (x, y), ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), (x, y), ps; - atol=1.0f-3, rtol=1.0f-3) + + @jet layer((x, y), ps, st) + __f = (x, y, ps) -> sum(first(layer((x, y), ps, st))) + @eval @test_gradients $__f $x $y $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu end @testset "Inner interactions" begin - x = randn(Float32, 2, 1) + x = randn(Float32, 2, 1) |> aType layer = Bilinear((2, 2) => 3) display(layer) - ps, st = Lux.setup(rng, layer) + ps, st = Lux.setup(rng, layer) .|> device @test size(layer(x, ps, st)[1]) == (3, 1) - @inferred layer(x, ps, st) - run_JET_tests(layer, x, ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; - atol=1.0f-3, rtol=1.0f-3) - x = randn(Float32, 2, 1) + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + + x = randn(Float32, 2, 1) |> aType layer = Bilinear(2 => 3) display(layer) - ps, st = Lux.setup(rng, layer) + ps, st = Lux.setup(rng, layer) .|> device @test size(layer(x, ps, st)[1]) == (3, 1) - @inferred layer(x, ps, st) - run_JET_tests(layer, x, ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; - atol=1.0f-3, rtol=1.0f-3) + + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu end end -@testset "Embedding" begin +@testset "$mode: Embedding" for (mode, aType, device, ongpu) in MODES vocab_size, embed_size = 10, 4 layer = Embedding(vocab_size => embed_size) display(layer) - ps, st = Lux.setup(rng, layer) + ps, st = Lux.setup(rng, layer) .|> device @test size(ps.weight) == (embed_size, vocab_size) @@ -293,26 +301,20 @@ end y, st_ = layer(x, ps, st) @test size(layer(x, ps, st)[1]) == (embed_size,) @test y == ps.weight[:, x] - @inferred layer(x, ps, st) - run_JET_tests(layer, x, ps, st) - test_gradient_correctness_fdm(ps -> sum(layer(x, ps, st)[1]), ps; atol=1.0f-3, - rtol=1.0f-3) - x = rand(1:vocab_size, 3) + @jet layer(x, ps, st) + + x = rand(1:vocab_size, 3) |> aType y, st_ = layer(x, ps, st) - @test y isa Matrix{Float32} + @test y isa aType{Float32} @test y == ps.weight[:, x] - @inferred layer(x, ps, st) - run_JET_tests(layer, x, ps, st) - test_gradient_correctness_fdm(ps -> sum(layer(x, ps, st)[1]), ps; atol=1.0f-3, - rtol=1.0f-3) - x = rand(1:vocab_size, 3, 4) + @jet layer(x, ps, st) + + x = rand(1:vocab_size, 3, 4) |> aType y, st_ = layer(x, ps, st) - @test y isa Array{Float32, 3} + @test y isa aType{Float32, 3} @test size(y) == (embed_size, 3, 4) - @inferred layer(x, ps, st) - run_JET_tests(layer, x, ps, st) - test_gradient_correctness_fdm(ps -> sum(layer(x, ps, st)[1]), ps; atol=1.0f-3, - rtol=1.0f-3) + + @jet layer(x, ps, st) end diff --git a/test/layers/containers.jl 
b/test/layers/containers.jl index 3042e9e84d..312ca880ad 100644 --- a/test/layers/containers.jl +++ b/test/layers/containers.jl @@ -5,88 +5,95 @@ include("../test_utils.jl") rng = Random.default_rng() Random.seed!(rng, 0) -@testset "SkipConnection" begin +@testset "$mode: SkipConnection" for (mode, aType, device, ongpu) in MODES @testset "zero sum" begin layer = SkipConnection(WrappedFunction(zero), (a, b) -> a .+ b) display(layer) - ps, st = Lux.setup(rng, layer) - x = randn(rng, 10, 10, 10, 10) + ps, st = Lux.setup(rng, layer) .|> device + x = randn(rng, 10, 10, 10, 10) |> aType @test layer(x, ps, st)[1] == x - run_JET_tests(layer, x, ps, st) - test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol=1.0f-3, - rtol=1.0f-3, reversediff_broken=true) + + @jet layer(x, ps, st) + __f = x -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 reverse_diff_broken=true gpu_testing=$ongpu end @testset "concat size" begin layer = SkipConnection(Dense(10, 10), (a, b) -> hcat(a, b)) display(layer) - ps, st = Lux.setup(rng, layer) - x = randn(rng, 10, 2) + ps, st = Lux.setup(rng, layer) .|> device + x = randn(rng, 10, 2) |> aType @test size(layer(x, ps, st)[1]) == (10, 4) - run_JET_tests(layer, x, ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; - atol=1.0f-3, rtol=1.0f-3) + + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu end end -@testset "Parallel" begin +@testset "$mode: Parallel" for (mode, aType, device, ongpu) in MODES @testset "zero sum" begin layer = Parallel(+, WrappedFunction(zero), NoOpLayer()) display(layer) - ps, st = Lux.setup(rng, layer) - x = randn(rng, 10, 10, 10, 10) + ps, st = Lux.setup(rng, layer) .|> device + x = randn(rng, 10, 10, 10, 10) |> aType @test layer(x, ps, st)[1] == x - run_JET_tests(layer, x, ps, st) - test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol=1.0f-3, - rtol=1.0f-3, reversediff_broken=true) + + @jet layer(x, ps, st) + __f = x -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 reverse_diff_broken=true gpu_testing=$ongpu end @testset "concat size" begin layer = Parallel((a, b) -> cat(a, b; dims=2), Dense(10, 10), NoOpLayer()) display(layer) - ps, st = Lux.setup(rng, layer) - x = randn(rng, 10, 2) + ps, st = Lux.setup(rng, layer) .|> device + x = randn(rng, 10, 2) |> aType @test size(layer(x, ps, st)[1]) == (10, 4) - run_JET_tests(layer, x, ps, st) - test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol=1.0f-3, - rtol=1.0f-3) + + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu layer = Parallel(hcat, Dense(10, 10), NoOpLayer()) display(layer) - ps, st = Lux.setup(rng, layer) + ps, st = Lux.setup(rng, layer) .|> device @test size(layer(x, ps, st)[1]) == (10, 4) - run_JET_tests(layer, x, ps, st) - test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol=1.0f-3, - rtol=1.0f-3) + + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu end @testset "vararg input" begin layer = Parallel(+, Dense(10, 2), Dense(5, 2), Dense(4, 2)) display(layer) - ps, st = Lux.setup(rng, layer) - x = (randn(rng, 10, 1), randn(rng, 5, 1), randn(rng, 4, 1)) + ps, st = Lux.setup(rng, layer) .|> device + x = (randn(rng, 10, 1), randn(rng, 
5, 1), randn(rng, 4, 1)) .|> aType @test size(layer(x, ps, st)[1]) == (2, 1) - run_JET_tests(layer, x, ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; - atol=1.0f-3, rtol=1.0f-3) + + @jet layer(x, ps, st) + __f = (x1, x2, x3, ps) -> sum(first(layer((x1, x2, x3), ps, st))) + @eval @test_gradients $__f $(x[1]) $(x[2]) $(x[3]) $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu end @testset "named layers" begin layer = Parallel(+; d102=Dense(10, 2), d52=Dense(5, 2), d42=Dense(4, 2)) display(layer) - ps, st = Lux.setup(rng, layer) - x = (randn(rng, 10, 1), randn(rng, 5, 1), randn(rng, 4, 1)) + ps, st = Lux.setup(rng, layer) .|> device + x = (randn(rng, 10, 1), randn(rng, 5, 1), randn(rng, 4, 1)) .|> aType @test size(layer(x, ps, st)[1]) == (2, 1) - run_JET_tests(layer, x, ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; - atol=1.0f-3, rtol=1.0f-3) + + @jet layer(x, ps, st) + __f = (x1, x2, x3, ps) -> sum(first(layer((x1, x2, x3), ps, st))) + @eval @test_gradients $__f $(x[1]) $(x[2]) $(x[3]) $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu end @testset "connection is called once" begin @@ -94,10 +101,10 @@ end f_cnt = (x...) -> (CNT[] += 1; +(x...)) layer = Parallel(f_cnt, WrappedFunction(sin), WrappedFunction(cos), WrappedFunction(tan)) - ps, st = Lux.setup(rng, layer) + ps, st = Lux.setup(rng, layer) .|> device Lux.apply(layer, 1, ps, st) @test CNT[] == 1 - run_JET_tests(layer, 1, ps, st) + @jet layer(1, ps, st) Lux.apply(layer, (1, 2, 3), ps, st) @test CNT[] == 2 layer = Parallel(f_cnt, WrappedFunction(sin)) @@ -117,48 +124,49 @@ end Base.:*(a::AbstractArray, b::Input) = a * b.x par = Parallel(+, L1(), L1()) - ps, st = Lux.setup(rng, par) + ps, st = Lux.setup(rng, par) .|> device - ip = Input(rand(Float32, 3, 3)) - ip2 = Input(rand(Float32, 3, 3)) + ip = Input(rand(Float32, 3, 3) |> aType) + ip2 = Input(rand(Float32, 3, 3) |> aType) - @test par(ip, ps, st)[1] ≈ - par.layers[1](ip.x, ps.layer_1, st.layer_1)[1] + - par.layers[2](ip.x, ps.layer_2, st.layer_2)[1] - @test par((ip, ip2), ps, st)[1] ≈ - par.layers[1](ip.x, ps.layer_1, st.layer_1)[1] + - par.layers[2](ip2.x, ps.layer_2, st.layer_2)[1] + @test check_approx(par(ip, ps, st)[1], + par.layers[1](ip.x, ps.layer_1, st.layer_1)[1] + + par.layers[2](ip.x, ps.layer_2, st.layer_2)[1]) + @test check_approx(par((ip, ip2), ps, st)[1], + par.layers[1](ip.x, ps.layer_1, st.layer_1)[1] + + par.layers[2](ip2.x, ps.layer_2, st.layer_2)[1]) gs = Zygote.gradient((p, x...) 
-> sum(par(x, p, st)[1]), ps, ip, ip2) gs_reg = Zygote.gradient(ps, ip, ip2) do p, x, y return sum(par.layers[1](x.x, p.layer_1, st.layer_1)[1] + par.layers[2](y.x, p.layer_2, st.layer_2)[1]) end - @test gs[1] ≈ gs_reg[1] - @test gs[2].x ≈ gs_reg[2].x - @test gs[3].x ≈ gs_reg[3].x + @test check_approx(gs[1], gs_reg[1]) + @test check_approx(gs[2].x, gs_reg[2].x) + @test check_approx(gs[3].x, gs_reg[3].x) end end -@testset "PairwiseFusion" begin - x = (rand(Float32, 1, 10), rand(Float32, 30, 10), rand(Float32, 10, 10)) +@testset "$mode: PairwiseFusion" for (mode, aType, device, ongpu) in MODES + x = (rand(Float32, 1, 10), rand(Float32, 30, 10), rand(Float32, 10, 10)) .|> aType layer = PairwiseFusion(+, Dense(1, 30), Dense(30, 10)) display(layer) - ps, st = Lux.setup(rng, layer) + ps, st = Lux.setup(rng, layer) .|> device y, _ = layer(x, ps, st) @test size(y) == (10, 10) - run_JET_tests(layer, x, ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; atol=1.0f-3, - rtol=1.0f-3) + + @jet layer(x, ps, st) + __f = (x1, x2, x3, ps) -> sum(first(layer((x1, x2, x3), ps, st))) + @eval @test_gradients $__f $(x[1]) $(x[2]) $(x[3]) $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu layer = PairwiseFusion(+; d1=Dense(1, 30), d2=Dense(30, 10)) display(layer) - ps, st = Lux.setup(rng, layer) + ps, st = Lux.setup(rng, layer) .|> device y, _ = layer(x, ps, st) @test size(y) == (10, 10) - run_JET_tests(layer, x, ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; atol=1.0f-3, - rtol=1.0f-3) + @jet layer(x, ps, st) + __f = (x1, x2, x3, ps) -> sum(first(layer((x1, x2, x3), ps, st))) + @eval @test_gradients $__f $(x[1]) $(x[2]) $(x[3]) $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu x = rand(1, 10) layer = PairwiseFusion(.+, Dense(1, 10), Dense(10, 1)) @@ -166,36 +174,38 @@ end ps, st = Lux.setup(rng, layer) y, _ = layer(x, ps, st) @test size(y) == (1, 10) - run_JET_tests(layer, x, ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; atol=1.0f-3, - rtol=1.0f-3) + + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu layer = PairwiseFusion(vcat, WrappedFunction(x -> x .+ 1), WrappedFunction(x -> x .+ 2), WrappedFunction(x -> x .^ 3)) display(layer) - ps, st = Lux.setup(rng, layer) + ps, st = Lux.setup(rng, layer) .|> device @test layer((2, 10, 20, 40), ps, st)[1] == [125, 1728, 8000, 40] layer = PairwiseFusion(vcat, WrappedFunction(x -> x .+ 1), WrappedFunction(x -> x .+ 2), WrappedFunction(x -> x .^ 3)) display(layer) - ps, st = Lux.setup(rng, layer) + ps, st = Lux.setup(rng, layer) .|> device @test layer(7, ps, st)[1] == [1000, 729, 343, 7] end -@testset "BranchLayer" begin +@testset "$mode: BranchLayer" for (mode, aType, device, ongpu) in MODES layer = BranchLayer(Dense(10, 10), Dense(10, 10)) display(layer) - ps, st = Lux.setup(rng, layer) - x = rand(Float32, 10, 1) + ps, st = Lux.setup(rng, layer) .|> device + x = rand(Float32, 10, 1) |> aType (y1, y2), _ = layer(x, ps, st) @test size(y1) == (10, 1) @test size(y2) == (10, 1) @test y1 == layer.layers.layer_1(x, ps.layer_1, st.layer_1)[1] @test y2 == layer.layers.layer_2(x, ps.layer_2, st.layer_2)[1] - run_JET_tests(layer, x, ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(sum, layer(x, ps, st)[1]), x, ps; - atol=1.0f-3, rtol=1.0f-3) + + @jet layer(x, ps, st) + __f = (x, ps) -> sum(sum, first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 
gpu_testing=$ongpu layer = BranchLayer(; d1=Dense(10, 10), d2=Dense(10, 10)) display(layer) @@ -206,121 +216,127 @@ end @test size(y2) == (10, 1) @test y1 == layer.layers.d1(x, ps.d1, st.d1)[1] @test y2 == layer.layers.d2(x, ps.d2, st.d2)[1] - run_JET_tests(layer, x, ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(sum, layer(x, ps, st)[1]), x, ps; - atol=1.0f-3, rtol=1.0f-3) + + @jet layer(x, ps, st) + __f = (x, ps) -> sum(sum, first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu end -@testset "Chain" begin +@testset "$mode: Chain" for (mode, aType, device, ongpu) in MODES layer = Chain(Dense(10 => 5, sigmoid), Dense(5 => 2, tanh), Dense(2 => 1)) display(layer) - ps, st = Lux.setup(rng, layer) - x = rand(Float32, 10, 1) + ps, st = Lux.setup(rng, layer) .|> device + x = rand(Float32, 10, 1) |> aType y, _ = layer(x, ps, st) @test size(y) == (1, 1) - run_JET_tests(layer, x, ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; atol=1.0f-3, - rtol=1.0f-3) + + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu layer = Chain(; l1=Dense(10 => 5, sigmoid), d52=Dense(5 => 2, tanh), d21=Dense(2 => 1)) display(layer) - ps, st = Lux.setup(rng, layer) - x = rand(Float32, 10, 1) + ps, st = Lux.setup(rng, layer) .|> device + x = rand(Float32, 10, 1) |> aType y, _ = layer(x, ps, st) @test size(y) == (1, 1) - run_JET_tests(layer, x, ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; atol=1.0f-3, - rtol=1.0f-3) + + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu layer = Chain(; l1=Dense(10 => 5, sigmoid), d52=Dense(5 => 2, tanh), d21=Dense(2 => 1)) display(layer) layer = layer[1:2] - ps, st = Lux.setup(rng, layer) - x = rand(Float32, 10, 1) + ps, st = Lux.setup(rng, layer) .|> device + x = rand(Float32, 10, 1) |> aType y, _ = layer(x, ps, st) @test size(y) == (2, 1) - run_JET_tests(layer, x, ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; atol=1.0f-3, - rtol=1.0f-3) + + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu layer = Chain(; l1=Dense(10 => 5, sigmoid), d52=Dense(5 => 2, tanh), d21=Dense(2 => 1)) display(layer) layer = layer[begin:(end - 1)] - ps, st = Lux.setup(rng, layer) - x = rand(Float32, 10, 1) + ps, st = Lux.setup(rng, layer) .|> device + x = rand(Float32, 10, 1) |> aType y, _ = layer(x, ps, st) @test size(y) == (2, 1) - run_JET_tests(layer, x, ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; atol=1.0f-3, - rtol=1.0f-3) + + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu layer = Chain(; l1=Dense(10 => 5, sigmoid), d52=Dense(5 => 2, tanh), d21=Dense(2 => 1)) display(layer) layer = layer[1] - ps, st = Lux.setup(rng, layer) - x = rand(Float32, 10, 1) + ps, st = Lux.setup(rng, layer) .|> device + x = rand(Float32, 10, 1) |> aType y, _ = layer(x, ps, st) @test size(y) == (5, 1) - run_JET_tests(layer, x, ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; atol=1.0f-3, - rtol=1.0f-3) + + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x 
$ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu @test_throws ArgumentError Chain(; l1=Dense(10 => 5, sigmoid), d52=Dense(5 => 2, tanh), d21=Dense(2 => 1), d2=Dense(2 => 1), disable_optimizations=false) end -@testset "Maxout" begin +@testset "$mode: Maxout" for (mode, aType, device, ongpu) in MODES @testset "constructor" begin layer = Maxout(() -> NoOpLayer(), 4) display(layer) - ps, st = Lux.setup(rng, layer) - x = rand(rng, Float32, 10, 1) + ps, st = Lux.setup(rng, layer) .|> device + x = rand(rng, Float32, 10, 1) |> aType @test layer(x, ps, st)[1] == x - run_JET_tests(layer, x, ps, st) - @inferred layer(x, ps, st) + + @jet layer(x, ps, st) end @testset "simple alternatives" begin layer = Maxout(NoOpLayer(), WrappedFunction(x -> 2x), WrappedFunction(x -> 0.5x)) display(layer) - ps, st = Lux.setup(rng, layer) - x = Float32.(collect(1:40)) + ps, st = Lux.setup(rng, layer) .|> device + x = Float32.(collect(1:40)) |> aType @test layer(x, ps, st)[1] == 2 .* x - run_JET_tests(layer, x, ps, st) - @inferred layer(x, ps, st) - test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol=1.0f-3, - rtol=1.0f-3) + + @jet layer(x, ps, st) + __f = x -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu end @testset "complex alternatives" begin - layer = Maxout(WrappedFunction(x -> [0.5; 0.1] * x), - WrappedFunction(x -> [0.2; 0.7] * x)) + layer = Maxout(WrappedFunction(x -> aType([0.5; 0.1]) * x), + WrappedFunction(x -> aType([0.2; 0.7]) * x)) display(layer) - ps, st = Lux.setup(rng, layer) - x = [3.0 2.0] - y = [0.5, 0.7] .* x + ps, st = Lux.setup(rng, layer) .|> device + x = [3.0 2.0] |> aType + y = aType([0.5, 0.7]) .* x @test layer(x, ps, st)[1] == y - run_JET_tests(layer, x, ps, st) - @inferred layer(x, ps, st) - test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol=1.0f-3, - rtol=1.0f-3) + + @jet layer(x, ps, st) + __f = x -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu end @testset "params" begin layer = Maxout(() -> Dense(2, 4), 4) display(layer) - ps, st = Lux.setup(rng, layer) - x = [10.0f0 3.0f0]' + ps, st = Lux.setup(rng, layer) .|> device + x = [10.0f0 3.0f0]' |> aType @test Lux.parameterlength(layer) == sum(Lux.parameterlength.(values(layer.layers))) @test size(layer(x, ps, st)[1]) == (4, 1) - run_JET_tests(layer, x, ps, st) - @inferred layer(x, ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; - atol=1.0f-1, rtol=1.0f-1) + + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-1 rtol=1.0f-1 gpu_testing=$ongpu end end diff --git a/test/layers/conv.jl b/test/layers/conv.jl index f5ed04f3dd..628d2b39ec 100644 --- a/test/layers/conv.jl +++ b/test/layers/conv.jl @@ -5,114 +5,117 @@ Random.seed!(rng, 0) include("../test_utils.jl") -@testset "Pooling" begin - x = randn(rng, Float32, 10, 10, 3, 2) - y = randn(rng, Float32, 20, 20, 3, 2) +@testset "$mode: Pooling" for (mode, aType, device, ongpu) in MODES + x = randn(rng, Float32, 10, 10, 3, 2) |> aType + y = randn(rng, Float32, 20, 20, 3, 2) |> aType layer = AdaptiveMaxPool((5, 5)) display(layer) - ps, st = Lux.setup(rng, layer) + ps, st = Lux.setup(rng, layer) .|> device @test layer(x, ps, st)[1] == maxpool(x, PoolDims(x, 2)) - run_JET_tests(layer, x, ps, st) + @jet layer(x, ps, st) layer = AdaptiveMeanPool((5, 5)) display(layer) - ps, st = Lux.setup(rng, layer) + ps, st = Lux.setup(rng, layer) .|> device @test 
layer(x, ps, st)[1] == meanpool(x, PoolDims(x, 2)) - run_JET_tests(layer, x, ps, st) + @jet layer(x, ps, st) layer = AdaptiveMaxPool((10, 5)) display(layer) - ps, st = Lux.setup(rng, layer) + ps, st = Lux.setup(rng, layer) .|> device @test layer(y, ps, st)[1] == maxpool(y, PoolDims(y, (2, 4))) - run_JET_tests(layer, x, ps, st) + @jet layer(y, ps, st) layer = AdaptiveMeanPool((10, 5)) display(layer) - ps, st = Lux.setup(rng, layer) + ps, st = Lux.setup(rng, layer) .|> device @test layer(y, ps, st)[1] == meanpool(y, PoolDims(y, (2, 4))) - run_JET_tests(layer, x, ps, st) + @jet layer(y, ps, st) layer = GlobalMaxPool() display(layer) - ps, st = Lux.setup(rng, layer) + ps, st = Lux.setup(rng, layer) .|> device @test size(layer(x, ps, st)[1]) == (1, 1, 3, 2) - run_JET_tests(layer, x, ps, st) + @jet layer(x, ps, st) layer = GlobalMeanPool() display(layer) - ps, st = Lux.setup(rng, layer) + ps, st = Lux.setup(rng, layer) .|> device @test size(layer(x, ps, st)[1]) == (1, 1, 3, 2) - run_JET_tests(layer, x, ps, st) + @jet layer(x, ps, st) layer = MaxPool((2, 2)) display(layer) - ps, st = Lux.setup(rng, layer) + ps, st = Lux.setup(rng, layer) .|> device @test layer(x, ps, st)[1] == maxpool(x, PoolDims(x, 2)) - run_JET_tests(layer, x, ps, st) + @jet layer(x, ps, st) layer = MeanPool((2, 2)) display(layer) - ps, st = Lux.setup(rng, layer) + ps, st = Lux.setup(rng, layer) .|> device @test layer(x, ps, st)[1] == meanpool(x, PoolDims(x, 2)) - run_JET_tests(layer, x, ps, st) + @jet layer(x, ps, st) @testset "$ltype SamePad windowsize $k" for ltype in (MeanPool, MaxPool), k in ((1,), (2,), (3,), (4, 5), (6, 7, 8)) - x = ones(Float32, (k .+ 3)..., 1, 1) + x = ones(Float32, (k .+ 3)..., 1, 1) |> aType layer = ltype(k; pad=Lux.SamePad()) display(layer) - ps, st = Lux.setup(rng, layer) + ps, st = Lux.setup(rng, layer) .|> device @test size(layer(x, ps, st)[1])[1:(end - 2)] == cld.(size(x)[1:(end - 2)], k) - run_JET_tests(layer, x, ps, st) + @jet layer(x, ps, st) end end -@testset "CNN" begin +@testset "$mode: CNN" for (mode, aType, device, ongpu) in MODES @testset "Grouped Conv" begin - x = rand(rng, Float32, 4, 6, 1) + x = rand(rng, Float32, 4, 6, 1) |> aType layer = Conv((3,), 6 => 2; groups=2) display(layer) - ps, st = Lux.setup(rng, layer) + ps, st = Lux.setup(rng, layer) .|> device @test size(ps.weight) == (3, 3, 2) @test size(layer(x, ps, st)[1]) == (2, 2, 1) - run_JET_tests(layer, x, ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; - atol=1.0f-3, rtol=1.0f-3) - x = rand(rng, Float32, 4, 4, 6, 1) + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + + x = rand(rng, Float32, 4, 4, 6, 1) |> aType layer = Conv((3, 3), 6 => 2; groups=2) display(layer) - ps, st = Lux.setup(rng, layer) + ps, st = Lux.setup(rng, layer) .|> device @test size(ps.weight) == (3, 3, 3, 2) @test size(layer(x, ps, st)[1]) == (2, 2, 2, 1) - run_JET_tests(layer, x, ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; - atol=1.0f-3, rtol=1.0f-3) - x = rand(rng, Float32, 4, 4, 4, 6, 1) + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + + x = rand(rng, Float32, 4, 4, 4, 6, 1) |> aType layer = Conv((3, 3, 3), 6 => 2; groups=2) display(layer) - ps, st = Lux.setup(rng, layer) + ps, st = Lux.setup(rng, layer) .|> device @test size(ps.weight) == (3, 3, 3, 3, 2) @test size(layer(x, ps, 
st)[1]) == (2, 2, 2, 2, 1) - run_JET_tests(layer, x, ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; - atol=1.0f-3, rtol=1.0f-3) + + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu # Test that we cannot ask for non-integer multiplication factors layer = Conv((2, 2), 3 => 10; groups=2) @@ -126,65 +129,68 @@ end @testset "Asymmetric Padding" begin layer = Conv((3, 3), 1 => 1, relu; pad=(0, 1, 1, 2)) display(layer) - x = ones(Float32, 28, 28, 1, 1) - ps, st = Lux.setup(rng, layer) + x = ones(Float32, 28, 28, 1, 1) |> aType + ps, st = Lux.setup(rng, layer) .|> device ps.weight .= 1.0 ps.bias .= 0.0 - y_hat = layer(x, ps, st)[1][:, :, 1, 1] + y_hat = layer(x, ps, st)[1][:, :, 1, 1] |> Array @test size(y_hat) == (27, 29) - @test y_hat[1, 1] ≈ 6.0 - @test y_hat[2, 2] ≈ 9.0 - @test y_hat[end, 1] ≈ 4.0 - @test y_hat[1, end] ≈ 3.0 - @test y_hat[1, end - 1] ≈ 6.0 - @test y_hat[end, end] ≈ 2.0 - - run_JET_tests(layer, x, ps, st) + @test check_approx(y_hat[1, 1], 6.0) + @test check_approx(y_hat[2, 2], 9.0) + @test check_approx(y_hat[end, 1], 4.0) + @test check_approx(y_hat[1, end], 3.0) + @test check_approx(y_hat[1, end - 1], 6.0) + @test check_approx(y_hat[end, end], 2.0) + + @jet layer(x, ps, st) end @testset "Variable BitWidth Parameters" begin # https://github.com/FluxML/Flux.jl/issues/1421 layer = Conv((5, 5), 10 => 20, identity; init_weight=Base.randn, - init_bias=(rng, dims...) -> randn(rng, Float16, dims...)) + init_bias=(rng, dims...) -> aType(randn(rng, Float16, dims...))) display(layer) ps, st = Lux.setup(rng, layer) - @test ps.weight isa Array{Float64, 4} - @test ps.bias isa Array{Float16, 4} + @test ps.weight isa aType{Float64, 4} + @test ps.bias isa aType{Float16, 4} end @testset "Depthwise Conv" begin - x = randn(rng, Float32, 4, 4, 3, 2) + x = randn(rng, Float32, 4, 4, 3, 2) |> aType layer = Conv((2, 2), 3 => 15; groups=3) display(layer) - ps, st = Lux.setup(rng, layer) + ps, st = Lux.setup(rng, layer) .|> device @test Lux.parameterlength(layer) == Lux.parameterlength(ps) @test size(layer(x, ps, st)[1], 3) == 15 - run_JET_tests(layer, x, ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; - atol=1.0f-3, rtol=1.0f-3) + + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu layer = Conv((2, 2), 3 => 9; groups=3) display(layer) - ps, st = Lux.setup(rng, layer) + ps, st = Lux.setup(rng, layer) .|> device @test size(layer(x, ps, st)[1], 3) == 9 - run_JET_tests(layer, x, ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; - atol=1.0f-3, rtol=1.0f-3) + + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu layer = Conv((2, 2), 3 => 9; groups=3, use_bias=false) display(layer) - ps, st = Lux.setup(rng, layer) + ps, st = Lux.setup(rng, layer) .|> device @test Lux.parameterlength(layer) == Lux.parameterlength(ps) @test size(layer(x, ps, st)[1], 3) == 9 - run_JET_tests(layer, x, ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; - atol=1.0f-3, rtol=1.0f-3) + + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu # Test that we cannot ask for non-integer multiplication factors layer = 
Conv((2, 2), 3 => 10; groups=3) @@ -193,76 +199,84 @@ end end @testset "Conv SamePad kernelsize $k" for k in ((1,), (2,), (3,), (2, 3), (1, 2, 3)) - x = ones(Float32, (k .+ 3)..., 1, 1) + x = ones(Float32, (k .+ 3)..., 1, 1) |> aType layer = Conv(k, 1 => 1; pad=Lux.SamePad()) display(layer) - ps, st = Lux.setup(rng, layer) + ps, st = Lux.setup(rng, layer) .|> device @test size(layer(x, ps, st)[1]) == size(x) - run_JET_tests(layer, x, ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; - atol=1.0f-3, rtol=1.0f-3) + + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu layer = Conv(k, 1 => 1; pad=Lux.SamePad(), dilation=k .÷ 2) display(layer) - ps, st = Lux.setup(rng, layer) + ps, st = Lux.setup(rng, layer) .|> device @test size(layer(x, ps, st)[1]) == size(x) - run_JET_tests(layer, x, ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; - atol=1.0f-3, rtol=1.0f-3) + + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu stride = 3 layer = Conv(k, 1 => 1; pad=Lux.SamePad(), stride=stride) display(layer) - ps, st = Lux.setup(rng, layer) + ps, st = Lux.setup(rng, layer) .|> device @test size(layer(x, ps, st)[1])[1:(end - 2)] == cld.(size(x)[1:(end - 2)], stride) - run_JET_tests(layer, x, ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; - atol=1.0f-3, rtol=1.0f-3) + + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu end @testset "Conv with non quadratic window #700" begin x = zeros(Float32, 7, 7, 1, 1) x[4, 4, 1, 1] = 1 + x = x |> aType layer = Conv((3, 3), 1 => 1) display(layer) - ps, st = Lux.setup(rng, layer) + ps, st = Lux.setup(rng, layer) .|> device - y = zeros(eltype(ps.weight), 5, 5, 1, 1) + y = zeros(eltype(ps.weight), 5, 5, 1, 1) |> aType y[2:(end - 1), 2:(end - 1), 1, 1] = ps.weight - @test y ≈ layer(x, ps, st)[1] - run_JET_tests(layer, x, ps, st) + @test check_approx(y, layer(x, ps, st)[1]) + + @jet layer(x, ps, st) layer = Conv((3, 1), 1 => 1) display(layer) - ps, st = Lux.setup(rng, layer) + ps, st = Lux.setup(rng, layer) .|> device - y = zeros(eltype(ps.weight), 5, 7, 1, 1) + y = zeros(eltype(ps.weight), 5, 7, 1, 1) |> aType y[2:(end - 1), 4, 1, 1] = ps.weight - @test y ≈ layer(x, ps, st)[1] - run_JET_tests(layer, x, ps, st) + @test check_approx(y, layer(x, ps, st)[1]) + + @jet layer(x, ps, st) layer = Conv((1, 3), 1 => 1) display(layer) - ps, st = Lux.setup(rng, layer) + ps, st = Lux.setup(rng, layer) .|> device - y = zeros(eltype(ps.weight), 7, 5, 1, 1) + y = zeros(eltype(ps.weight), 7, 5, 1, 1) |> aType y[4, 2:(end - 1), 1, 1] = ps.weight - @test y ≈ layer(x, ps, st)[1] - run_JET_tests(layer, x, ps, st) + @test check_approx(y, layer(x, ps, st)[1]) + + @jet layer(x, ps, st) layer = Conv((1, 3), 1 => 1; init_weight=Lux.glorot_normal) display(layer) - ps, st = Lux.setup(rng, layer) + ps, st = Lux.setup(rng, layer) .|> device - y = zeros(eltype(ps.weight), 7, 5, 1, 1) + y = zeros(eltype(ps.weight), 7, 5, 1, 1) |> aType y[4, 2:(end - 1), 1, 1] = ps.weight - @test y ≈ layer(x, ps, st)[1] - run_JET_tests(layer, x, ps, st) + @test check_approx(y, layer(x, ps, st)[1]) + + @jet layer(x, ps, st) end @testset "allow fast activation" begin @@ -282,7 +296,7 @@ end end end -@testset "Upsample" begin 
+@testset "$mode: Upsample" for (mode, aType, device, ongpu) in MODES @testset "Construction" begin @test_nowarn Upsample(:nearest; scale=2) @test_nowarn Upsample(:nearest; size=(64, 64)) @@ -306,16 +320,16 @@ end sizes = (nothing, (64, 64), (64, 32)) scales = (nothing, 2, (2, 1)) - for mode in modes, xsize in sizes, scale in scales + for umode in modes, xsize in sizes, scale in scales if !xor(isnothing(xsize), isnothing(scale)) continue end - layer = Upsample(mode; size=xsize, scale=scale) + layer = Upsample(umode; size=xsize, scale=scale) display(layer) - ps, st = Lux.setup(rng, layer) - x = zeros((32, 32, 3, 4)) + ps, st = Lux.setup(rng, layer) .|> device + x = zeros((32, 32, 3, 4)) |> aType - run_JET_tests(layer, x, ps, st) + @jet layer(x, ps, st) y, _ = layer(x, ps, st) if isnothing(scale) @@ -329,16 +343,16 @@ end sizes = (nothing, (64, 64, 64), (64, 32, 128)) scales = (nothing, 2, (2, 1, 1), (2, 2, 1)) - for mode in modes, xsize in sizes, scale in scales + for umode in modes, xsize in sizes, scale in scales if !xor(isnothing(xsize), isnothing(scale)) continue end - layer = Upsample(mode; size=xsize, scale=scale) + layer = Upsample(umode; size=xsize, scale=scale) display(layer) - ps, st = Lux.setup(rng, layer) - x = zeros((32, 32, 32, 3, 4)) + ps, st = Lux.setup(rng, layer) .|> device + x = zeros((32, 32, 32, 3, 4)) |> aType - run_JET_tests(layer, x, ps, st) + @jet layer(x, ps, st) y, _ = layer(x, ps, st) @@ -352,92 +366,99 @@ end end end -@testset "PixelShuffle" begin +@testset "$mode: PixelShuffle" for (mode, aType, device, ongpu) in MODES layer = PixelShuffle(2) display(layer) - ps, st = Lux.setup(rng, layer) - x = rand(rng, Float32, 3, 6, 3) + ps, st = Lux.setup(rng, layer) .|> device + x = rand(rng, Float32, 3, 6, 3) |> aType y, st_ = layer(x, ps, st) - @test y isa Array{Float32, 3} + @test y isa aType{Float32, 3} @test size(y) == (6, 3, 3) - run_JET_tests(layer, x, ps, st) - test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol=1e-3, rtol=1e-3) + + @jet layer(x, ps, st) + __f = x -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x gpu_testing=$ongpu atol=1e-3 rtol=1e-3 layer = PixelShuffle(3) display(layer) - ps, st = Lux.setup(rng, layer) - x = rand(Float32, 3, 4, 9, 3) + ps, st = Lux.setup(rng, layer) .|> device + x = rand(Float32, 3, 4, 9, 3) |> aType y, st_ = layer(x, ps, st) - @test y isa Array{Float32, 4} + @test y isa aType{Float32, 4} @test size(y) == (9, 12, 1, 3) - run_JET_tests(layer, x, ps, st) - test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol=1e-3, rtol=1e-3) + + @jet layer(x, ps, st) + __f = x -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x gpu_testing=$ongpu atol=1e-3 rtol=1e-3 end -@testset "CrossCor" begin +@testset "$mode: CrossCor" for (mode, aType, device, ongpu) in MODES @testset "Asymmetric Padding" begin layer = CrossCor((3, 3), 1 => 1, relu; pad=(0, 1, 1, 2)) display(layer) - x = ones(Float32, 28, 28, 1, 1) - ps, st = Lux.setup(rng, layer) + x = ones(Float32, 28, 28, 1, 1) |> aType + ps, st = Lux.setup(rng, layer) .|> device ps.weight .= 1.0 ps.bias .= 0.0 - y_hat = layer(x, ps, st)[1][:, :, 1, 1] + y_hat = layer(x, ps, st)[1][:, :, 1, 1] |> Array @test size(y_hat) == (27, 29) - @test y_hat[1, 1] ≈ 6.0 - @test y_hat[2, 2] ≈ 9.0 - @test y_hat[end, 1] ≈ 4.0 - @test y_hat[1, end] ≈ 3.0 - @test y_hat[1, end - 1] ≈ 6.0 - @test y_hat[end, end] ≈ 2.0 - - run_JET_tests(layer, x, ps, st) + @test check_approx(y_hat[1, 1], 6.0) + @test check_approx(y_hat[2, 2], 9.0) + @test check_approx(y_hat[end, 1], 
4.0) + @test check_approx(y_hat[1, end], 3.0) + @test check_approx(y_hat[1, end - 1], 6.0) + @test check_approx(y_hat[end, end], 2.0) + + @jet layer(x, ps, st) end @testset "Variable BitWidth Parameters" begin # https://github.com/FluxML/Flux.jl/issues/1421 layer = CrossCor((5, 5), 10 => 20, identity; init_weight=Base.randn, - init_bias=(rng, dims...) -> randn(rng, Float16, dims...)) + init_bias=(rng, dims...) -> aType(randn(rng, Float16, dims...))) display(layer) ps, st = Lux.setup(rng, layer) - @test ps.weight isa Array{Float64, 4} - @test ps.bias isa Array{Float16, 4} + @test ps.weight isa aType{Float64, 4} + @test ps.bias isa aType{Float16, 4} end @testset "CrossCor SamePad kernelsize $k" for k in ((1,), (2,), (3,), (2, 3), (1, 2, 3)) - x = ones(Float32, (k .+ 3)..., 1, 1) + x = ones(Float32, (k .+ 3)..., 1, 1) |> aType layer = CrossCor(k, 1 => 1; pad=Lux.SamePad()) display(layer) - ps, st = Lux.setup(rng, layer) + ps, st = Lux.setup(rng, layer) .|> device @test size(layer(x, ps, st)[1]) == size(x) - run_JET_tests(layer, x, ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; - atol=1.0f-3, rtol=1.0f-3) + + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps gpu_testing=$ongpu atol=1e-3 rtol=1e-3 layer = CrossCor(k, 1 => 1; pad=Lux.SamePad(), dilation=k .÷ 2) display(layer) - ps, st = Lux.setup(rng, layer) + ps, st = Lux.setup(rng, layer) .|> device @test size(layer(x, ps, st)[1]) == size(x) - run_JET_tests(layer, x, ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; - atol=1.0f-3, rtol=1.0f-3) + + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps gpu_testing=$ongpu atol=1e-3 rtol=1e-3 stride = 3 layer = CrossCor(k, 1 => 1; pad=Lux.SamePad(), stride=stride) display(layer) - ps, st = Lux.setup(rng, layer) + ps, st = Lux.setup(rng, layer) .|> device @test size(layer(x, ps, st)[1])[1:(end - 2)] == cld.(size(x)[1:(end - 2)], stride) - run_JET_tests(layer, x, ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; - atol=1.0f-3, rtol=1.0f-3) + + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps gpu_testing=$ongpu atol=1e-3 rtol=1e-3 end @testset "allow fast activation" begin @@ -448,120 +469,121 @@ end end end -@testset "ConvTranspose" begin - x = randn(Float32, 5, 5, 1, 1) +@testset "$mode: ConvTranspose" for (mode, aType, device, ongpu) in MODES + x = randn(Float32, 5, 5, 1, 1) |> aType layer = Conv((3, 3), 1 => 1) - ps, st = Lux.setup(rng, layer) + ps, st = Lux.setup(rng, layer) .|> device y = layer(x, ps, st)[1] layer = ConvTranspose((3, 3), 1 => 1) display(layer) - ps, st = Lux.setup(rng, layer) - run_JET_tests(layer, y, ps, st; opt_broken=true) - @static if VERSION >= v"1.7" - # Inference broken in v1.6 - @inferred layer(y, ps, st) - end + ps, st = Lux.setup(rng, layer) .|> device + + @jet layer(y, ps, st) opt_broken=true + x_hat1 = layer(y, ps, st)[1] layer = ConvTranspose((3, 3), 1 => 1; use_bias=false) display(layer) - ps, st = Lux.setup(rng, layer) - run_JET_tests(layer, y, ps, st; opt_broken=true) - @static if VERSION >= v"1.7" - # Inference broken in v1.6 - @inferred layer(y, ps, st) - end + ps, st = Lux.setup(rng, layer) .|> device + + @jet layer(y, ps, st) opt_broken=true + x_hat2 = layer(y, ps, st)[1] @test size(x_hat1) == size(x_hat2) == size(x) layer = ConvTranspose((3, 3), 1 => 1) display(layer) - ps, st = 
Lux.setup(rng, layer) - x = rand(Float32, 5, 5, 1, 1) - run_JET_tests(layer, x, ps, st; opt_broken=true) - @static if VERSION >= v"1.7" - # Inference broken in v1.6 - @inferred layer(x, ps, st) - end - test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; atol=1.0f-3, - rtol=1.0f-3) + ps, st = Lux.setup(rng, layer) .|> device + x = rand(Float32, 5, 5, 1, 1) |> aType - x = rand(Float32, 5, 5, 2, 4) + @jet layer(x, ps, st) opt_broken=true + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps gpu_testing=$ongpu atol=1e-3 rtol=1e-3 + + x = rand(Float32, 5, 5, 2, 4) |> aType layer = ConvTranspose((3, 3), 2 => 3) display(layer) - ps, st = Lux.setup(rng, layer) - run_JET_tests(layer, x, ps, st; opt_broken=true) - @static if VERSION >= v"1.7" - # Inference broken in v1.6 - @inferred layer(x, ps, st) - end - test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; atol=1.0f-3, - rtol=1.0f-3) + ps, st = Lux.setup(rng, layer) .|> device + + @jet layer(x, ps, st) opt_broken=true + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps gpu_testing=$ongpu atol=1e-3 rtol=1e-3 # test ConvTranspose supports groups argument - x = randn(Float32, 10, 10, 2, 3) + x = randn(Float32, 10, 10, 2, 3) |> aType layer1 = ConvTranspose((3, 3), 2 => 4; pad=SamePad()) display(layer1) - ps1, st1 = Lux.setup(rng, layer1) + ps1, st1 = Lux.setup(rng, layer1) .|> device @test size(ps1.weight) == (3, 3, 4, 2) @test size(layer1(x, ps1, st1)[1]) == (10, 10, 4, 3) layer2 = ConvTranspose((3, 3), 2 => 4; groups=2, pad=SamePad()) display(layer2) - ps2, st2 = Lux.setup(rng, layer2) + ps2, st2 = Lux.setup(rng, layer2) .|> device @test size(ps2.weight) == (3, 3, 2, 2) @test size(layer1(x, ps1, st1)[1]) == size(layer2(x, ps2, st2)[1]) - test_gradient_correctness_fdm((x, ps) -> sum(layer1(x, ps, st1)[1]), x, ps1; - atol=1.0f-3, rtol=1.0f-3) - test_gradient_correctness_fdm((x, ps) -> sum(layer2(x, ps, st2)[1]), x, ps2; - atol=1.0f-3, rtol=1.0f-3) + __f = (x, ps) -> sum(first(layer1(x, ps, st1))) + @eval @test_gradients $__f $x $ps1 gpu_testing=$ongpu atol=1e-3 rtol=1e-3 + + __f = (x, ps) -> sum(first(layer2(x, ps, st2))) + @eval @test_gradients $__f $x $ps2 gpu_testing=$ongpu atol=1e-3 rtol=1e-3 - x = randn(Float32, 10, 2, 1) + x = randn(Float32, 10, 2, 1) |> aType layer = ConvTranspose((3,), 2 => 4; pad=SamePad(), groups=2) display(layer) - ps, st = Lux.setup(rng, layer) - run_JET_tests(layer, x, ps, st; opt_broken=true) + ps, st = Lux.setup(rng, layer) .|> device + + @jet layer(x, ps, st) opt_broken=true + @test size(layer(x, ps, st)[1]) == (10, 4, 1) @test length(ps.weight) == 3 * (2 * 4) / 2 - test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; atol=1.0f-3, - rtol=1.0f-3) - x = randn(Float32, 10, 11, 4, 2) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps gpu_testing=$ongpu atol=1e-3 rtol=1e-3 + + x = randn(Float32, 10, 11, 4, 2) |> aType layer = ConvTranspose((3, 5), 4 => 4; pad=SamePad(), groups=4) display(layer) - ps, st = Lux.setup(rng, layer) - run_JET_tests(layer, x, ps, st; opt_broken=true) + ps, st = Lux.setup(rng, layer) .|> device + + @jet layer(x, ps, st) opt_broken=true + @test size(layer(x, ps, st)[1]) == (10, 11, 4, 2) @test length(ps.weight) == (3 * 5) * (4 * 4) / 4 - test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; atol=1.0f-3, - rtol=1.0f-3) - x = randn(Float32, 10, 11, 4, 2) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients 
$__f $x $ps gpu_testing=$ongpu atol=1e-3 rtol=1e-3 + + x = randn(Float32, 10, 11, 4, 2) |> aType layer = ConvTranspose((3, 5), 4 => 4, tanh; pad=SamePad(), groups=4) display(layer) - ps, st = Lux.setup(rng, layer) - run_JET_tests(layer, x, ps, st; opt_broken=true) + ps, st = Lux.setup(rng, layer) .|> device + + @jet layer(x, ps, st) opt_broken=true @test size(layer(x, ps, st)[1]) == (10, 11, 4, 2) @test length(ps.weight) == (3 * 5) * (4 * 4) / 4 - test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; atol=1.0f-3, - rtol=1.0f-3) - x = randn(Float32, 10, 11, 12, 3, 2) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps gpu_testing=$ongpu atol=1e-3 rtol=1e-3 + + x = randn(Float32, 10, 11, 12, 3, 2) |> aType layer = ConvTranspose((3, 5, 3), 3 => 6; pad=SamePad(), groups=3) display(layer) - ps, st = Lux.setup(rng, layer) - run_JET_tests(layer, x, ps, st; opt_broken=true) + ps, st = Lux.setup(rng, layer) .|> device + + @jet layer(x, ps, st) opt_broken=true @test size(layer(x, ps, st)[1]) == (10, 11, 12, 6, 2) @test length(ps.weight) == (3 * 5 * 3) * (3 * 6) / 3 - x = randn(Float32, 10, 11, 12, 3, 2) + x = randn(Float32, 10, 11, 12, 3, 2) |> aType layer = ConvTranspose((3, 5, 3), 3 => 6, tanh; pad=SamePad(), groups=3) display(layer) - ps, st = Lux.setup(rng, layer) - run_JET_tests(layer, x, ps, st; opt_broken=true) + ps, st = Lux.setup(rng, layer) .|> device + + @jet layer(x, ps, st) opt_broken=true @test size(layer(x, ps, st)[1]) == (10, 11, 12, 6, 2) @test length(ps.weight) == (3 * 5 * 3) * (3 * 6) / 3 diff --git a/test/layers/dropout.jl b/test/layers/dropout.jl index 60004dbde1..9531715ae6 100644 --- a/test/layers/dropout.jl +++ b/test/layers/dropout.jl @@ -5,84 +5,91 @@ include("../test_utils.jl") rng = Random.default_rng() Random.seed!(rng, 0) -@testset "Dropout" begin for p in (0.5f0, 0.5) - layer = Dropout(p) - display(layer) - ps, st = Lux.setup(rng, layer) - x = randn(Float32, 5, 2) - - x_, st_ = layer(x, ps, st) - x__, st__ = layer(x, ps, st) - x___, st___ = layer(x_, ps, st_) - - @test st_.rng != st.rng - @test st_.rng == st__.rng - @test x_ == x__ - @test x_ != x___ - - run_JET_tests(layer, x, ps, st) - test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol=1.0f-3, - rtol=1.0f-3) - - st = Lux.testmode(st) - - @test first(layer(x, ps, st)) == x -end end - -@testset "AlphaDropout" begin for p in (0.5f0, 0.5) - layer = AlphaDropout(p) - display(layer) - ps, st = Lux.setup(rng, layer) - x = randn(Float32, 5, 2) - - x_, st_ = layer(x, ps, st) - x__, st__ = layer(x, ps, st) - x___, st___ = layer(x_, ps, st_) - - @test st_.rng != st.rng - @test st_.rng == st__.rng - @test x_ == x__ - @test x_ != x___ - - run_JET_tests(layer, x, ps, st) - test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol=1.0f-3, - rtol=1.0f-3) - - st = Lux.testmode(st) - - @test first(layer(x, ps, st)) == x -end end - -@testset "VariationalHiddenDropout" begin for p in (0.5f0, 0.5) - layer = VariationalHiddenDropout(p) - display(layer) - ps, st = Lux.setup(rng, layer) - x = randn(Float32, 5, 2) - - x_, st_ = layer(x, ps, st) - x__, st__ = layer(x, ps, st) - x___, st___ = layer(x_, ps, st_) - - @test st_.rng != st.rng - @test st_.rng == st__.rng - @test st_.mask == st__.mask - @test x_ == x__ - @test x_ != x___ - - run_JET_tests(layer, x, ps, st) - run_JET_tests(layer, x, ps, st_) - test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol=1.0f-3, - rtol=1.0f-3) - test_gradient_correctness_fdm(x -> sum(layer(x, ps, st_)[1]), x; 
atol=1.0f-3, - rtol=1.0f-3) - - st__ = Lux.update_state(st_, :update_mask, Val(true)) - x___, st___ = layer(x, ps, st__) - - @test st___.mask != st__.mask - @test x___ != x_ - - run_JET_tests(layer, x, ps, st__) - test_gradient_correctness_fdm(x -> sum(layer(x, ps, st__)[1]), x; atol=1.0f-3, - rtol=1.0f-3) -end end +@testset "$mode: Dropout" for (mode, aType, device, ongpu) in MODES + for p in (0.5f0, 0.5) + layer = Dropout(p) + display(layer) + ps, st = Lux.setup(rng, layer) .|> device + x = randn(Float32, 5, 2) |> aType + + x_, st_ = layer(x, ps, st) + x__, st__ = layer(x, ps, st) + x___, st___ = layer(x_, ps, st_) + + @test st_.rng != st.rng + @test st_.rng == st__.rng + @test x_ == x__ + @test x_ != x___ + + @jet layer(x, ps, st) + __f = x -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + + st = Lux.testmode(st) + + @test first(layer(x, ps, st)) == x + end +end + +@testset "$mode: AlphaDropout" for (mode, aType, device, ongpu) in MODES + for p in (0.5f0, 0.5) + layer = AlphaDropout(p) + display(layer) + ps, st = Lux.setup(rng, layer) .|> device + x = randn(Float32, 5, 2) |> aType + + x_, st_ = layer(x, ps, st) + x__, st__ = layer(x, ps, st) + x___, st___ = layer(x_, ps, st_) + + @test st_.rng != st.rng + @test st_.rng == st__.rng + @test x_ == x__ + @test x_ != x___ + + @jet layer(x, ps, st) + __f = x -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + + st = Lux.testmode(st) + + @test first(layer(x, ps, st)) == x + end +end + +@testset "$mode: VariationalHiddenDropout" for (mode, aType, device, ongpu) in MODES + for p in (0.5f0, 0.5) + layer = VariationalHiddenDropout(p) + display(layer) + ps, st = Lux.setup(rng, layer) .|> device + x = randn(Float32, 5, 2) |> aType + + x_, st_ = layer(x, ps, st) + x__, st__ = layer(x, ps, st) + x___, st___ = layer(x_, ps, st_) + + @test st_.rng != st.rng + @test st_.rng == st__.rng + @test st_.mask == st__.mask + @test x_ == x__ + @test x_ != x___ + + @jet layer(x, ps, st) + __f = x -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + + @jet layer(x, ps, st_) + __f = x -> sum(first(layer(x, ps, st_))) + @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + + st__ = Lux.update_state(st_, :update_mask, Val(true)) + x___, st___ = layer(x, ps, st__) + + @test st___.mask != st__.mask + @test x___ != x_ + + @jet layer(x, ps, st__) + __f = x -> sum(first(layer(x, ps, st__))) + @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + end +end diff --git a/test/layers/normalize.jl b/test/layers/normalize.jl index 1c053d2a5e..a70fc4ab2b 100644 --- a/test/layers/normalize.jl +++ b/test/layers/normalize.jl @@ -5,21 +5,22 @@ include("../test_utils.jl") rng = Random.default_rng() Random.seed!(rng, 0) -@testset "BatchNorm" begin +@testset "$mode: BatchNorm" for (mode, aType, device, ongpu) in MODES m = BatchNorm(2) x = [1.0f0 3.0f0 5.0f0 - 2.0f0 4.0f0 6.0f0] + 2.0f0 4.0f0 6.0f0] |> aType display(m) - ps, st = Lux.setup(rng, m) + ps, st = Lux.setup(rng, m) .|> device @test Lux.parameterlength(m) == Lux.parameterlength(ps) @test Lux.statelength(m) == Lux.statelength(st) - @test ps.bias == [0, 0] # init_bias(2) - @test ps.scale == [1, 1] # init_scale(2) + @test ps.bias == [0, 0] |> aType # init_bias(2) + @test ps.scale == [1, 1] |> aType # init_scale(2) y, st_ = pullback(m, x, ps, st)[1] - @test isapprox(y, [-1.22474 0 1.22474; -1.22474 0 1.22474], atol=1.0e-5) 
+ st_ = st_ |> cpu + @test check_approx(Array(y), [-1.22474 0 1.22474; -1.22474 0 1.22474]; atol=1.0e-5) # julia> x # 2×3 Array{Float64,2}: # 1.0 3.0 5.0 @@ -32,73 +33,71 @@ Random.seed!(rng, 0) # ∴ update rule with momentum: # .1 * 3 + 0 = .3 # .1 * 4 + 0 = .4 - @test st_.running_mean ≈ reshape([0.3, 0.4], 2, 1) + @test check_approx(st_.running_mean, reshape([0.3, 0.4], 2, 1)) # julia> .1 .* var(x, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.] # 2×1 Array{Float64,2}: # 1.3 # 1.3 - @test st_.running_var ≈ - 0.1 .* var(x; dims=2, corrected=false) .* (3 / 2) .+ 0.9 .* [1.0, 1.0] + @test check_approx(st_.running_var, + 0.1 .* var(Array(x); dims=2, corrected=false) .* (3 / 2) .+ + 0.9 .* [1.0, 1.0]) st_ = Lux.testmode(st_) - x_ = m(x, ps, st_)[1] - @test isapprox(x_[1], (1 .- 0.3) / sqrt(1.3), atol=1.0e-5) + x_ = m(x, ps, st_)[1] |> cpu + @test check_approx(x_[1], (1 .- 0.3) / sqrt(1.3), atol=1.0e-5) - @inferred first(m(x, ps, st)) - - run_JET_tests(m, x, ps, st) - - test_gradient_correctness_fdm((x, ps) -> sum(first(m(x, ps, st))), x, ps; atol=1.0f-3, - rtol=1.0f-3) + @jet m(x, ps, st) + __f = (x, ps) -> sum(first(m(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu for affine in (true, false) m = BatchNorm(2; affine, track_stats=false) - x = [1.0f0 3.0f0 5.0f0; 2.0f0 4.0f0 6.0f0] + x = [1.0f0 3.0f0 5.0f0; 2.0f0 4.0f0 6.0f0] |> aType display(m) - ps, st = Lux.setup(rng, m) - @inferred first(m(x, ps, st)) - run_JET_tests(m, x, ps, st) + ps, st = Lux.setup(rng, m) .|> device + + @jet m(x, ps, st) if affine - test_gradient_correctness_fdm((x, ps) -> sum(first(m(x, ps, st))), x, ps; - atol=1.0f-3, rtol=1.0f-3) + __f = (x, ps) -> sum(first(m(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu else - test_gradient_correctness_fdm(x -> sum(first(m(x, ps, st))), x; atol=1.0f-3, - rtol=1.0f-3) + __f = x -> sum(first(m(x, ps, st))) + @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu end # with activation function m = BatchNorm(2, sigmoid; affine) x = [1.0f0 3.0f0 5.0f0 - 2.0f0 4.0f0 6.0f0] + 2.0f0 4.0f0 6.0f0] |> aType display(m) - ps, st = Lux.setup(rng, m) + ps, st = Lux.setup(rng, m) .|> device st = Lux.testmode(st) y, st_ = m(x, ps, st) - @test isapprox(y, - sigmoid.((x .- st_.running_mean) ./ - sqrt.(st_.running_var .+ m.epsilon)), atol=1.0e-7) - @inferred first(m(x, ps, st)) - run_JET_tests(m, x, ps, st) + @test check_approx(y, + sigmoid.((x .- st_.running_mean) ./ + sqrt.(st_.running_var .+ m.epsilon)), atol=1.0e-7) + + @jet m(x, ps, st) if affine - test_gradient_correctness_fdm((x, ps) -> sum(first(m(x, ps, st))), x, ps; - atol=1.0f-3, rtol=1.0f-3) + __f = (x, ps) -> sum(first(m(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu else - test_gradient_correctness_fdm(x -> sum(first(m(x, ps, st))), x; atol=1.0f-3, - rtol=1.0f-3) + __f = x -> sum(first(m(x, ps, st))) + @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu end m = BatchNorm(32; affine) - x = randn(Float32, 416, 416, 32, 1) + x = randn(Float32, 416, 416, 32, 1) |> aType display(m) ps, st = Lux.setup(rng, m) st = Lux.testmode(st) m(x, ps, st) @test (@allocated m(x, ps, st)) < 100_000_000 - @inferred first(m(x, ps, st)) - run_JET_tests(m, x, ps, st) + + @jet m(x, ps, st) end @testset "allow fast activation" begin @@ -109,23 +108,24 @@ Random.seed!(rng, 0) end end -@testset "GroupNorm" begin +@testset "$mode: GroupNorm" for (mode, aType, device, ongpu) in MODES # 
begin tests squeeze(x) = dropdims(x; dims=tuple(findall(size(x) .== 1)...)) # To remove all singular dimensions m = GroupNorm(4, 2; track_stats=true) sizes = (3, 4, 2) - x = reshape(collect(1:prod(sizes)), sizes) + x = reshape(collect(1:prod(sizes)), sizes) |> aType display(m) x = Float32.(x) - ps, st = Lux.setup(rng, m) + ps, st = Lux.setup(rng, m) .|> device @test Lux.parameterlength(m) == Lux.parameterlength(ps) @test Lux.statelength(m) == Lux.statelength(st) - @test ps.bias == [0, 0, 0, 0] # init_bias(32) - @test ps.scale == [1, 1, 1, 1] # init_scale(32) + @test ps.bias == [0, 0, 0, 0] |> aType # init_bias(32) + @test ps.scale == [1, 1, 1, 1] |> aType # init_scale(32) y, st_ = pullback(m, x, ps, st)[1] + y = y |> Array # julia> x # [:, :, 1] = @@ -152,68 +152,67 @@ end # ∴ update rule with momentum: # (1. - .1) * 0 + .1 * (3.5 + 15.5) / 2 = 0.95 # (1. - .1) * 0 + .1 * (9.5 + 21.5) / 2 = 1.55 - @test st_.running_mean ≈ [0.95, 1.55] + @test check_approx(st_.running_mean, aType([0.95, 1.55])) n = prod(size(x)) ÷ m.groups ÷ size(x)[end] corr = n / (n - 1) z = reshape(x, 3, 2, 2, 2) variance = var(z; dims=(1, 2), corrected=false) - @test st_.running_var ≈ 0.1 * corr * vec(mean(variance; dims=4)) .+ 0.9 * 1 + @test check_approx(st_.running_var, 0.1 * corr * vec(mean(variance; dims=4)) .+ 0.9 * 1) st__ = Lux.testmode(st_) y, st__ = m(x, ps, st__) out = (z .- reshape(st_.running_mean, 1, 1, 2, 1)) ./ sqrt.(reshape(st_.running_var, 1, 1, 2, 1) .+ 1.0f-5) - @test y≈reshape(out, size(x)) atol=1.0e-5 + @test check_approx(y, reshape(out, size(x)); atol=1.0e-5) - @inferred first(m(x, ps, st)) - run_JET_tests(m, x, ps, st) - test_gradient_correctness_fdm(ps -> sum(first(m(x, ps, st))), ps; atol=1.0f-3, - rtol=1.0f-3) + @jet m(x, ps, st) + __f = ps -> sum(first(m(x, ps, st))) + @eval @test_gradients $__f $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu for affine in (true, false) m = GroupNorm(2, 2; affine, track_stats=false) - x = randn(rng, Float32, 3, 2, 1) + x = randn(rng, Float32, 3, 2, 1) |> aType display(m) - ps, st = Lux.setup(rng, m) - @inferred first(m(x, ps, st)) - run_JET_tests(m, x, ps, st) + ps, st = Lux.setup(rng, m) .|> device + + @jet m(x, ps, st) if affine - test_gradient_correctness_fdm((x, ps) -> sum(first(m(x, ps, st))), x, ps; - atol=1.0f-3, rtol=1.0f-3) + __f = (x, ps) -> sum(first(m(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu else - test_gradient_correctness_fdm(x -> sum(first(m(x, ps, st))), x; atol=1.0f-3, - rtol=1.0f-3) + __f = x -> sum(first(m(x, ps, st))) + @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu end # with activation function m = GroupNorm(2, 2, sigmoid; affine) - x = randn(rng, Float32, 3, 2, 1) + x = randn(rng, Float32, 3, 2, 1) |> aType display(m) - ps, st = Lux.setup(rng, m) + ps, st = Lux.setup(rng, m) .|> device st = Lux.testmode(st) y, st_ = m(x, ps, st) - @inferred first(m(x, ps, st)) - run_JET_tests(m, x, ps, st) + @jet m(x, ps, st) if affine - test_gradient_correctness_fdm((x, ps) -> sum(first(m(x, ps, st))), x, ps; - atol=1.0f-3, rtol=1.0f-3) + __f = (x, ps) -> sum(first(m(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu else - test_gradient_correctness_fdm(x -> sum(first(m(x, ps, st))), x; atol=1.0f-3, - rtol=1.0f-3) + __f = x -> sum(first(m(x, ps, st))) + @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu end m = GroupNorm(32, 16; affine) - x = randn(rng, Float32, 416, 416, 32, 1) + x = randn(rng, Float32, 416, 416, 
32, 1) |> aType display(m) - ps, st = Lux.setup(rng, m) + ps, st = Lux.setup(rng, m) .|> device st = Lux.testmode(st) m(x, ps, st) + @test (@allocated m(x, ps, st)) < 100_000_000 - @inferred first(m(x, ps, st)) - run_JET_tests(m, x, ps, st) + + @jet m(x, ps, st) end @test_throws AssertionError GroupNorm(5, 2) @@ -230,9 +229,9 @@ end @test_deprecated GroupNorm(4, 2; track_stats=false, momentum=0.3f0) end -@testset "WeightNorm" begin +@testset "$mode: WeightNorm" for (mode, aType, device, ongpu) in MODES @testset "_norm_except" begin - z = randn(rng, Float32, 3, 3, 4, 2) + z = randn(rng, Float32, 3, 3, 4, 2) |> aType @test size(Lux._norm(z; dims=(1, 2))) == (1, 1, 4, 2) @test size(Lux._norm_except(z; dims=1)) == (3, 1, 1, 1) @@ -240,8 +239,9 @@ end @test size(Lux._norm_except(z; dims=(1, 2))) == (3, 3, 1, 1) @test Lux._norm_except(z; dims=(1, 2)) == Lux._norm(z; dims=(3, 4)) - run_JET_tests(Lux._norm_except, z) - run_JET_tests(x -> Lux._norm_except(x; dims=(3, 4)), z) + @jet Lux._norm_except(z) + __f = z -> sum(Lux._norm_except(z; dims=(3, 4))) + @jet __f(z) end @testset "Conv" begin @@ -249,39 +249,39 @@ end wn = WeightNorm(c, (:weight, :bias)) display(wn) - ps, st = Lux.setup(rng, wn) - x = randn(rng, Float32, 3, 3, 3, 1) + ps, st = Lux.setup(rng, wn) .|> device + x = randn(rng, Float32, 3, 3, 3, 1) |> aType - run_JET_tests(wn, x, ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(first(wn(x, ps, st))), x, ps; - atol=1.0f-3, rtol=1.0f-3) + @jet wn(x, ps, st) + __f = ps -> sum(first(wn(x, ps, st))) + @eval @test_gradients $__f $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu wn = WeightNorm(c, (:weight,)) display(wn) - ps, st = Lux.setup(rng, wn) - x = randn(rng, Float32, 3, 3, 3, 1) + ps, st = Lux.setup(rng, wn) .|> device + x = randn(rng, Float32, 3, 3, 3, 1) |> aType - run_JET_tests(wn, x, ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(first(wn(x, ps, st))), x, ps; - atol=1.0f-3, rtol=1.0f-3) + @jet wn(x, ps, st) + __f = (x, ps) -> sum(first(wn(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu wn = WeightNorm(c, (:weight, :bias), (2, 2)) display(wn) - ps, st = Lux.setup(rng, wn) - x = randn(rng, Float32, 3, 3, 3, 1) + ps, st = Lux.setup(rng, wn) .|> device + x = randn(rng, Float32, 3, 3, 3, 1) |> aType - run_JET_tests(wn, x, ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(first(wn(x, ps, st))), x, ps; - atol=1.0f-3, rtol=1.0f-3) + @jet wn(x, ps, st) + __f = (x, ps) -> sum(first(wn(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu wn = WeightNorm(c, (:weight,), (2,)) display(wn) - ps, st = Lux.setup(rng, wn) - x = randn(rng, Float32, 3, 3, 3, 1) + ps, st = Lux.setup(rng, wn) .|> device + x = randn(rng, Float32, 3, 3, 3, 1) |> aType - run_JET_tests(wn, x, ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(first(wn(x, ps, st))), x, ps; - atol=1.0f-3, rtol=1.0f-3) + @jet wn(x, ps, st) + __f = (x, ps) -> sum(first(wn(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu end @testset "Dense" begin @@ -289,39 +289,39 @@ end wn = WeightNorm(d, (:weight, :bias)) display(wn) - ps, st = Lux.setup(rng, wn) - x = randn(rng, Float32, 3, 1) + ps, st = Lux.setup(rng, wn) .|> device + x = randn(rng, Float32, 3, 1) |> aType - run_JET_tests(wn, x, ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(first(wn(x, ps, st))), x, ps; - atol=1.0f-3, rtol=1.0f-3) + @jet wn(x, ps, st) + __f = ps -> sum(first(wn(x, ps, st))) + @eval @test_gradients $__f $ps atol=1.0f-3 
rtol=1.0f-3 gpu_testing=$ongpu wn = WeightNorm(d, (:weight,)) display(wn) - ps, st = Lux.setup(rng, wn) - x = randn(rng, Float32, 3, 1) + ps, st = Lux.setup(rng, wn) .|> device + x = randn(rng, Float32, 3, 1) |> aType - run_JET_tests(wn, x, ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(first(wn(x, ps, st))), x, ps; - atol=1.0f-3, rtol=1.0f-3) + @jet wn(x, ps, st) + __f = (x, ps) -> sum(first(wn(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu wn = WeightNorm(d, (:weight, :bias), (2, 2)) display(wn) - ps, st = Lux.setup(rng, wn) - x = randn(rng, Float32, 3, 1) + ps, st = Lux.setup(rng, wn) .|> device + x = randn(rng, Float32, 3, 1) |> aType - run_JET_tests(wn, x, ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(first(wn(x, ps, st))), x, ps; - atol=1.0f-3, rtol=1.0f-3) + @jet wn(x, ps, st) + __f = (x, ps) -> sum(first(wn(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu wn = WeightNorm(d, (:weight,), (2,)) display(wn) - ps, st = Lux.setup(rng, wn) - x = randn(rng, Float32, 3, 1) + ps, st = Lux.setup(rng, wn) .|> device + x = randn(rng, Float32, 3, 1) |> aType - run_JET_tests(wn, x, ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(first(wn(x, ps, st))), x, ps; - atol=1.0f-3, rtol=1.0f-3) + @jet wn(x, ps, st) + __f = (x, ps) -> sum(first(wn(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu end # See https://github.com/avik-pal/Lux.jl/issues/95 @@ -344,47 +344,45 @@ end end end -@testset "LayerNorm" begin - x = randn(rng, Float32, 3, 3, 3, 2) +@testset "$mode: LayerNorm" for (mode, aType, device, ongpu) in MODES + x = randn(rng, Float32, 3, 3, 3, 2) |> aType for bshape in ((3, 3, 3), (1, 3, 1), (3, 1, 3)) for affine in (true, false) ln = LayerNorm(bshape; affine) display(ln) - ps, st = Lux.setup(rng, ln) + ps, st = Lux.setup(rng, ln) .|> device - @inferred first(ln(x, ps, st)) y, st_ = ln(x, ps, st) - @test isapprox(mean(y), 0; atol=1.0f-3, rtol=1.0f-3) - @test isapprox(std(y), 1; atol=1.0f-2, rtol=1.0f-2) + @test check_approx(mean(y), 0; atol=1.0f-3, rtol=1.0f-3) + @test check_approx(std(y), 1; atol=1.0f-2, rtol=1.0f-2) - run_JET_tests(ln, x, ps, st) + @jet ln(x, ps, st) if affine - test_gradient_correctness_fdm((x, ps) -> sum(first(ln(x, ps, st))), x, ps; - atol=1.0f-1, rtol=1.0f-1) + __f = (x, ps) -> sum(first(ln(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu else - test_gradient_correctness_fdm(x -> sum(first(ln(x, ps, st))), x; - atol=1.0f-1, rtol=1.0f-1) + __f = x -> sum(first(ln(x, ps, st))) + @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu end for act in (sigmoid, tanh) ln = LayerNorm(bshape, act; affine) display(ln) - ps, st = Lux.setup(rng, ln) + ps, st = Lux.setup(rng, ln) .|> device - @inferred first(ln(x, ps, st)) y, st_ = ln(x, ps, st) - run_JET_tests(ln, x, ps, st) + @jet ln(x, ps, st) if affine - test_gradient_correctness_fdm((x, ps) -> sum(first(ln(x, ps, st))), x, - ps; atol=1.0f-1, rtol=1.0f-1) + __f = (x, ps) -> sum(first(ln(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu else - test_gradient_correctness_fdm(x -> sum(first(ln(x, ps, st))), x; - atol=1.0f-1, rtol=1.0f-1) + __f = x -> sum(first(ln(x, ps, st))) + @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu end end end @@ -398,43 +396,42 @@ end end end -@testset "InstanceNorm" begin +@testset "$mode: InstanceNorm" for (mode, aType, device, 
ongpu) in MODES for x in (randn(rng, Float32, 3, 3, 3, 2), randn(rng, Float32, 3, 3, 2), randn(rng, Float32, 3, 3, 3, 3, 2)) + x = x |> aType for affine in (true, false) layer = InstanceNorm(3; affine) display(layer) - ps, st = Lux.setup(rng, layer) + ps, st = Lux.setup(rng, layer) .|> device - @inferred first(layer(x, ps, st)) y, st_ = layer(x, ps, st) - run_JET_tests(layer, x, ps, st) + @jet layer(x, ps, st) if affine - test_gradient_correctness_fdm((x, ps) -> sum(first(layer(x, ps, st))), x, - ps; atol=1.0f-1, rtol=1.0f-1) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu else - test_gradient_correctness_fdm(x -> sum(first(layer(x, ps, st))), x; - atol=1.0f-1, rtol=1.0f-1) + __f = x -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu end for act in (sigmoid, tanh) layer = InstanceNorm(3, act; affine) display(layer) - ps, st = Lux.setup(rng, layer) + ps, st = Lux.setup(rng, layer) .|> device - @inferred first(layer(x, ps, st)) y, st_ = layer(x, ps, st) - run_JET_tests(layer, x, ps, st) + @jet layer(x, ps, st) if affine - test_gradient_correctness_fdm((x, ps) -> sum(first(layer(x, ps, st))), - x, ps; atol=1.0f-1, rtol=1.0f-1) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu else - test_gradient_correctness_fdm(x -> sum(first(layer(x, ps, st))), x; - atol=1.0f-1, rtol=1.0f-1) + __f = x -> sum(first(layer(x, ps, st))) + @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu end end end diff --git a/test/layers/recurrent.jl b/test/layers/recurrent.jl index 08d5b553ca..1b30a30eee 100644 --- a/test/layers/recurrent.jl +++ b/test/layers/recurrent.jl @@ -5,18 +5,18 @@ include("../test_utils.jl") rng = Random.default_rng() Random.seed!(rng, 0) -@testset "RNNCell" begin +@testset "$mode: RNNCell" for (mode, aType, device, ongpu) in MODES for rnncell in (RNNCell(3 => 5, identity), RNNCell(3 => 5, tanh), RNNCell(3 => 5, tanh; use_bias=false), RNNCell(3 => 5, identity; use_bias=false), RNNCell(3 => 5, identity; use_bias=false, train_state=false)) display(rnncell) - ps, st = Lux.setup(rng, rnncell) - x = randn(rng, Float32, 3, 2) + ps, st = Lux.setup(rng, rnncell) .|> device + x = randn(rng, Float32, 3, 2) |> aType (y, carry), st_ = Lux.apply(rnncell, x, ps, st) - run_JET_tests(rnncell, x, ps, st) - run_JET_tests(rnncell, (x, carry), ps, st_) + @jet rnncell(x, ps, st) + @jet rnncell((x, carry), ps, st) function loss_loop_rnncell(p) (y, carry), st_ = rnncell(x, p, st) @@ -28,7 +28,7 @@ Random.seed!(rng, 0) @test_throws ErrorException ps.train_state - test_gradient_correctness_fdm(loss_loop_rnncell, ps; atol=1e-2, rtol=1e-2) + @eval @test_gradients $loss_loop_rnncell $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu end @testset "Trainable hidden states" begin for rnncell in (RNNCell(3 => 5, identity; @@ -39,12 +39,12 @@ Random.seed!(rng, 0) train_state=true)) rnn_no_trainable_state = RNNCell(3 => 5, identity; use_bias=false, train_state=false) - x = randn(rng, Float32, 3, 2) - _ps, _st = Lux.setup(rng, rnn_no_trainable_state) + x = randn(rng, Float32, 3, 2) |> aType + _ps, _st = Lux.setup(rng, rnn_no_trainable_state) .|> device (_y, _carry), _ = Lux.apply(rnn_no_trainable_state, x, _ps, _st) rnncell = RNNCell(3 => 5, identity; use_bias=false, train_state=true) - ps, st = Lux.setup(rng, rnncell) + ps, st = Lux.setup(rng, rnncell) .|> device ps = merge(_ps, 
(hidden_state=ps.hidden_state,)) (y, carry), _ = Lux.apply(rnncell, x, ps, st) @test carry == _carry @@ -62,16 +62,16 @@ Random.seed!(rng, 0) end end -@testset "LSTMCell" begin +@testset "$mode: LSTMCell" for (mode, aType, device, ongpu) in MODES for lstmcell in (LSTMCell(3 => 5), LSTMCell(3 => 5; use_bias=true), LSTMCell(3 => 5; use_bias=false)) display(lstmcell) - ps, st = Lux.setup(rng, lstmcell) - x = randn(rng, Float32, 3, 2) + ps, st = Lux.setup(rng, lstmcell) .|> device + x = randn(rng, Float32, 3, 2) |> aType (y, carry), st_ = Lux.apply(lstmcell, x, ps, st) - run_JET_tests(lstmcell, x, ps, st) - run_JET_tests(lstmcell, (x, carry), ps, st_) + @jet lstmcell(x, ps, st) + @jet lstmcell((x, carry), ps, st) function loss_loop_lstmcell(p) (y, carry), st_ = lstmcell(x, p, st) @@ -81,20 +81,20 @@ end return sum(abs2, y) end - test_gradient_correctness_fdm(loss_loop_lstmcell, ps; atol=1e-2, rtol=1e-2) + @eval @test_gradients $loss_loop_lstmcell $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu @test_throws ErrorException ps.train_state @test_throws ErrorException ps.train_memory end @testset "Trainable hidden states" begin - x = randn(rng, Float32, 3, 2) + x = randn(rng, Float32, 3, 2) |> aType _lstm = LSTMCell(3 => 5; use_bias=false, train_state=false, train_memory=false) - _ps, _st = Lux.setup(rng, _lstm) + _ps, _st = Lux.setup(rng, _lstm) .|> device (_y, _carry), _ = Lux.apply(_lstm, x, _ps, _st) lstm = LSTMCell(3 => 5; use_bias=false, train_state=false, train_memory=false) - ps, st = Lux.setup(rng, lstm) + ps, st = Lux.setup(rng, lstm) .|> device ps = _ps (y, carry), _ = Lux.apply(lstm, x, ps, st) @test carry == _carry @@ -105,7 +105,7 @@ end @test_throws ErrorException gs.memory lstm = LSTMCell(3 => 5; use_bias=false, train_state=true, train_memory=false) - ps, st = Lux.setup(rng, lstm) + ps, st = Lux.setup(rng, lstm) .|> device ps = merge(_ps, (hidden_state=ps.hidden_state,)) (y, carry), _ = Lux.apply(lstm, x, ps, st) @test carry == _carry @@ -116,7 +116,7 @@ end @test_throws ErrorException gs.memory lstm = LSTMCell(3 => 5; use_bias=false, train_state=false, train_memory=true) - ps, st = Lux.setup(rng, lstm) + ps, st = Lux.setup(rng, lstm) .|> device ps = merge(_ps, (memory=ps.memory,)) (y, carry), _ = Lux.apply(lstm, x, ps, st) @test carry == _carry @@ -127,7 +127,7 @@ end @test !isnothing(gs.memory) lstm = LSTMCell(3 => 5; use_bias=false, train_state=true, train_memory=true) - ps, st = Lux.setup(rng, lstm) + ps, st = Lux.setup(rng, lstm) .|> device ps = merge(_ps, (hidden_state=ps.hidden_state, memory=ps.memory)) (y, carry), _ = Lux.apply(lstm, x, ps, st) @test carry == _carry @@ -138,7 +138,7 @@ end @test !isnothing(gs.memory) lstm = LSTMCell(3 => 5; use_bias=true, train_state=true, train_memory=true) - ps, st = Lux.setup(rng, lstm) + ps, st = Lux.setup(rng, lstm) .|> device ps = merge(_ps, (bias=ps.bias, hidden_state=ps.hidden_state, memory=ps.memory)) (y, carry), _ = Lux.apply(lstm, x, ps, st) l, back = Zygote.pullback(p -> sum(abs2, 0 .- sum(lstm(x, p, st)[1][1])), ps) @@ -149,16 +149,16 @@ end end end -@testset "GRUCell" begin +@testset "$mode: GRUCell" for (mode, aType, device, ongpu) in MODES for grucell in (GRUCell(3 => 5), GRUCell(3 => 5; use_bias=true), GRUCell(3 => 5; use_bias=false)) display(grucell) - ps, st = Lux.setup(rng, grucell) - x = randn(rng, Float32, 3, 2) + ps, st = Lux.setup(rng, grucell) .|> device + x = randn(rng, Float32, 3, 2) |> aType (y, carry), st_ = Lux.apply(grucell, x, ps, st) - run_JET_tests(grucell, x, ps, st) - run_JET_tests(grucell, (x, carry), ps, 
st_) + @jet grucell(x, ps, st) + @jet grucell((x, carry), ps, st) function loss_loop_grucell(p) (y, carry), st_ = grucell(x, p, st) @@ -168,19 +168,19 @@ end return sum(abs2, y) end - test_gradient_correctness_fdm(loss_loop_grucell, ps; atol=1e-2, rtol=1e-2) + @eval @test_gradients $loss_loop_grucell $ps atol=1e-2 rtol=1e-2 gpu_testing=$ongpu @test_throws ErrorException ps.train_state end @testset "Trainable hidden states" begin - x = randn(rng, Float32, 3, 2) + x = randn(rng, Float32, 3, 2) |> aType _gru = GRUCell(3 => 5; use_bias=false, train_state=false) - _ps, _st = Lux.setup(rng, _gru) + _ps, _st = Lux.setup(rng, _gru) .|> device (_y, _carry), _ = Lux.apply(_gru, x, _ps, _st) gru = GRUCell(3 => 5; use_bias=false, train_state=false) - ps, st = Lux.setup(rng, gru) + ps, st = Lux.setup(rng, gru) .|> device ps = _ps (y, carry), _ = Lux.apply(gru, x, ps, st) @test carry == _carry @@ -190,7 +190,7 @@ end @test_throws ErrorException gs.hidden_state gru = GRUCell(3 => 5; use_bias=false, train_state=true) - ps, st = Lux.setup(rng, gru) + ps, st = Lux.setup(rng, gru) .|> device ps = merge(_ps, (hidden_state=ps.hidden_state,)) (y, carry), _ = Lux.apply(gru, x, ps, st) @test carry == _carry @@ -199,7 +199,7 @@ end @test !isnothing(gs.hidden_state) gru = GRUCell(3 => 5; use_bias=true, train_state=true) - ps, st = Lux.setup(rng, gru) + ps, st = Lux.setup(rng, gru) .|> device ps = merge(_ps, (bias_h=ps.bias_h, bias_i=ps.bias_i, hidden_state=ps.hidden_state)) (y, carry), _ = Lux.apply(gru, x, ps, st) @test carry == _carry @@ -209,66 +209,75 @@ end end end -@testset "StatefulRecurrentCell" begin for _cell in (RNNCell, LSTMCell, GRUCell), - use_bias in (true, false), - train_state in (true, false) +@testset "$mode: StatefulRecurrentCell" for (mode, aType, device, ongpu) in MODES + for _cell in (RNNCell, LSTMCell, GRUCell), + use_bias in (true, false), + train_state in (true, false) - cell = _cell(3 => 5; use_bias, train_state) - rnn = StatefulRecurrentCell(cell) - display(rnn) - x = randn(rng, Float32, 3, 2) - ps, st = Lux.setup(rng, rnn) + cell = _cell(3 => 5; use_bias, train_state) + rnn = StatefulRecurrentCell(cell) + display(rnn) + x = randn(rng, Float32, 3, 2) |> aType + ps, st = Lux.setup(rng, rnn) .|> device - y, st_ = rnn(x, ps, st) + y, st_ = rnn(x, ps, st) - run_JET_tests(rnn, x, ps, st) - run_JET_tests(rnn, x, ps, st_) + @jet rnn(x, ps, st) + @jet rnn(x, ps, st_) - @test size(y) == (5, 2) - @test st.carry === nothing - @test st_.carry !== nothing + @test size(y) == (5, 2) + @test st.carry === nothing + @test st_.carry !== nothing - st__ = Lux.update_state(st, :carry, nothing) - @test st__.carry === nothing + st__ = Lux.update_state(st, :carry, nothing) + @test st__.carry === nothing - function loss_loop_rnn(p) - y, st_ = rnn(x, p, st) - for i in 1:10 - y, st_ = rnn(x, p, st_) + function loss_loop_rnn(p) + y, st_ = rnn(x, p, st) + for i in 1:10 + y, st_ = rnn(x, p, st_) + end + return sum(abs2, y) end - return sum(abs2, y) - end - - test_gradient_correctness_fdm(loss_loop_rnn, ps; atol=1e-2, rtol=1e-2) -end end - -@testset "Recurrence" begin for _cell in (RNNCell, LSTMCell, GRUCell), - use_bias in (true, false), - train_state in (true, false) - cell = _cell(3 => 5; use_bias, train_state) - rnn = Recurrence(cell) - rnn_seq = Recurrence(cell; return_sequence=true) - display(rnn) - - # Batched Time Series - for x in (randn(rng, Float32, 3, 4, 2), Tuple(randn(rng, Float32, 3, 2) for _ in 1:4), - [randn(rng, Float32, 3, 2) for _ in 1:4]) - ps, st = Lux.setup(rng, rnn) - y, st_ = rnn(x, ps, st) - 
y_, st__ = rnn_seq(x, ps, st) - run_JET_tests(rnn, x, ps, st) - run_JET_tests(rnn_seq, x, ps, st) - - @test size(y) == (5, 2) - @test length(y_) == 4 - @test all(x -> size(x) == (5, 2), y_) + @eval @test_gradients $loss_loop_rnn $ps atol=1e-2 rtol=1e-2 gpu_testing=$ongpu + end +end - test_gradient_correctness_fdm(p -> sum(rnn(x, p, st)[1]), ps; atol=1e-2, rtol=1e-2) - test_gradient_correctness_fdm(p -> sum(Base.Fix1(sum, abs2), rnn_seq(x, p, st)[1]), - ps; atol=1e-2, rtol=1e-2) +@testset "$mode: Recurrence" for (mode, aType, device, ongpu) in MODES + for _cell in (RNNCell, LSTMCell, GRUCell), + use_bias in (true, false), + train_state in (true, false) + + cell = _cell(3 => 5; use_bias, train_state) + rnn = Recurrence(cell) + rnn_seq = Recurrence(cell; return_sequence=true) + display(rnn) + + # Batched Time Series + for x in (randn(rng, Float32, 3, 4, 2), + Tuple(randn(rng, Float32, 3, 2) for _ in 1:4), + [randn(rng, Float32, 3, 2) for _ in 1:4]) + x = x |> aType + ps, st = Lux.setup(rng, rnn) .|> device + y, st_ = rnn(x, ps, st) + y_, st__ = rnn_seq(x, ps, st) + + @jet rnn(x, ps, st) + @jet rnn_seq(x, ps, st) + + @test size(y) == (5, 2) + @test length(y_) == 4 + @test all(x -> size(x) == (5, 2), y_) + + __f = p -> sum(first(rnn(x, p, st))) + @eval @test_gradients $__f $ps atol=1e-2 rtol=1e-2 gpu_testing=$ongpu + + __f = p -> sum(Base.Fix1(sum, abs2), first(rnn_seq(x, p, st))) + @eval @test_gradients $__f $ps atol=1e-2 rtol=1e-2 gpu_testing=$ongpu + end end -end end +end @testset "multigate" begin x = rand(6, 5) diff --git a/test/nnlib.jl b/test/nnlib.jl index 84483e9f0a..3aa68351ae 100644 --- a/test/nnlib.jl +++ b/test/nnlib.jl @@ -1,15 +1,15 @@ -using CUDA, Lux, Random, Test +using Lux, Random, Test include("test_utils.jl") -@testset "Elementwise Operation Dispatches" begin +@testset "$mode: Elementwise Operation Dispatches" for (mode, aType, device, ongpu) in MODES rng = Random.default_rng() Random.seed!(rng, 0) custom_activation(x) = abs(x) for T in [Float64, Float32, ComplexF64, ComplexF32] - x = randn(rng, T, 10, 5, 2) - y = randn(rng, T, 10, 1, 2) + x = randn(rng, T, 10, 5, 2) |> aType + y = randn(rng, T, 10, 1, 2) |> aType # On CPU the fallback should always work @test Lux.elementwise_add(x, y) == x .+ y @@ -19,19 +19,7 @@ include("test_utils.jl") if T <: Real # Gradient for complex outputs are not defined - test_gradient_correctness_fdm(sum ∘ Lux.elementwise_add, x, y) - end - - # On GPU try to use CUDNN - if CUDA.functional() - x_g = x |> gpu - y_g = y |> gpu - - @test Lux.elementwise_add(x_g, y_g) == x_g .+ y_g - @test Lux.elementwise_mul(x_g, y_g) == x_g .* y_g - @test Lux.applyactivation(tanh, x_g) == tanh.(x_g) - # Custom Activation test - @test Lux.applyactivation(custom_activation, x_g) == custom_activation.(x_g) + @eval @test_gradients $(sum ∘ Lux.elementwise_add) $x $y gpu_testing=$ongpu end # Deprecated Functionality (Remove in v0.5) diff --git a/test/runtests.jl b/test/runtests.jl index b983db92d8..98ca814cad 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -18,8 +18,6 @@ using SafeTestsets, Test @time @safetestset "NNlib" begin include("nnlib.jl") end - @time @safetestset "Automatic Differentiation" begin include("autodiff.jl") end - @testset "Experimental" begin @time @safetestset "Map" begin include("contrib/map.jl") end @time @safetestset "Training" begin include("contrib/training.jl") end diff --git a/test/test_utils.jl b/test/test_utils.jl index cd99fe16b9..8db8f55868 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -1,107 +1,31 @@ -using 
ComponentArrays, FiniteDifferences, Functors, Lux, Optimisers, Random, Test -import ReverseDiff, Tracker, Zygote +using Lux, LuxCore, LuxLib, LuxTestUtils, Test, Zygote +using LuxCUDA # CUDA Support +using LuxTestUtils: @jet, @test_gradients, check_approx -try - using JET -catch - @warn "JET not not precompiling. All JET tests will be skipped." maxlog=1 - global test_call(args...; kwargs...) = nothing - global test_opt(args...; kwargs...) = nothing -end - -function Base.isapprox(x, y; kwargs...) - @warn "`isapprox` is not defined for ($(typeof(x)), $(typeof(y))). Using `==` instead." - return x == y -end - -function Base.isapprox(x::Tuple, y::Tuple; kwargs...) - return all(isapprox.(x, y; kwargs...)) -end - -function Base.isapprox(x::Optimisers.Leaf, y::Optimisers.Leaf; kwargs...) - return isapprox(x.rule, y.rule; kwargs...) && isapprox(x.state, y.state; kwargs...) -end - -function Base.isapprox(nt1::NamedTuple{fields}, nt2::NamedTuple{fields}; - kwargs...) where {fields} - checkapprox(xy) = isapprox(xy[1], xy[2]; kwargs...) - checkapprox(t::Tuple{Nothing, Nothing}) = true - return all(checkapprox, zip(values(nt1), values(nt2))) -end - -function Base.isapprox(t1::NTuple{N, T}, t2::NTuple{N, T}; kwargs...) where {N, T} - checkapprox(xy) = isapprox(xy[1], xy[2]; kwargs...) - checkapprox(t::Tuple{Nothing, Nothing}) = true - return all(checkapprox, zip(t1, t2)) -end - -Base.isapprox(::Nothing, v::AbstractArray; kwargs...) = length(v) == 0 -Base.isapprox(v::AbstractArray, ::Nothing; kwargs...) = length(v) == 0 -Base.isapprox(v::NamedTuple, ::Nothing; kwargs...) = length(v) == 0 -Base.isapprox(::Nothing, v::NamedTuple; kwargs...) = length(v) == 0 -Base.isapprox(v::Tuple, ::Nothing; kwargs...) = length(v) == 0 -Base.isapprox(::Nothing, v::Tuple; kwargs...) = length(v) == 0 -Base.isapprox(x::AbstractArray, y::NamedTuple; kwargs...) = length(x) == 0 && length(y) == 0 -Base.isapprox(x::NamedTuple, y::AbstractArray; kwargs...) = length(x) == 0 && length(y) == 0 -Base.isapprox(x::AbstractArray, y::Tuple; kwargs...) = length(x) == 0 && length(y) == 0 -Base.isapprox(x::Tuple, y::AbstractArray; kwargs...) = length(x) == 0 && length(y) == 0 +const GROUP = get(ENV, "GROUP", "All") -_named_tuple(x::ComponentArray) = NamedTuple(x) -_named_tuple(x) = x +cpu_testing() = GROUP == "All" || GROUP == "CPU" +cuda_testing() = (GROUP == "All" || GROUP == "CUDA") && LuxCUDA.functional() +amdgpu_testing() = (GROUP == "All" || GROUP == "AMDGPU") # && LuxAMDGPU.functional() -# Test the gradients generated using AD against the gradients generated using Finite Differences -function test_gradient_correctness_fdm(f::Function, args...; reversediff_broken=false, - kwargs...) - gs_ad_zygote = Zygote.gradient(f, args...) +const MODES = begin + # Mode, Array Type, Device Function, GPU? + cpu_mode = ("CPU", Array, cpu, false) + cuda_mode = ("CUDA", CuArray, gpu, true) - gs_ad_tracker = Tracker.gradient(f, args...) - - # ReverseDiff requires AbstractArray inputs - if any(!Base.Fix2(isa, AbstractArray), args) - rdiff_skipped = true - gs_ad_rdiff = fmap(zero, args) - else - rdiff_skipped = false - gs_ad_rdiff = _named_tuple.(ReverseDiff.gradient(f, ComponentArray.(args))) - end - - gs_fdm = _named_tuple.(FiniteDifferences.grad(FiniteDifferences.central_fdm(5, 1), f, - ComponentArray.(args)...)) - - for (g_ad_zygote, g_ad_tracker, g_ad_rdiff, g_fdm) in zip(gs_ad_zygote, gs_ad_tracker, - gs_ad_rdiff, gs_fdm) - @test isapprox(g_ad_zygote, g_fdm; kwargs...) - @test isapprox(Tracker.data(g_ad_tracker), g_ad_zygote; kwargs...) 
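For reference, a minimal sketch of how the harness defined above in test/test_utils.jl is consumed by the individual test files; the destructuring of MODES and the `@jet` / `@test_gradients` calls mirror the testsets elsewhere in this patch, while the Dense layer and tolerances are purely illustrative. Setting the GROUP environment variable to "CPU" or "CUDA" before the test harness loads restricts MODES to that backend.

```julia
using Lux, Random, Test
# Assumes test/test_utils.jl has been included, providing MODES, @jet and @test_gradients.

@testset "$mode: Dense smoke test" for (mode, aType, device, ongpu) in MODES
    d = Dense(3 => 2)
    ps, st = Lux.setup(Random.default_rng(), d) .|> device
    x = randn(Float32, 3, 4) |> aType

    @jet d(x, ps, st)  # JET-based call/optimization checks from LuxTestUtils

    __f = (x, ps) -> sum(first(d(x, ps, st)))
    @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu
end
```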
- if !rdiff_skipped - if reversediff_broken - @test_broken isapprox(g_ad_rdiff, g_ad_zygote; kwargs...) - else - @test isapprox(g_ad_rdiff, g_ad_zygote; kwargs...) - end - end - end + modes = [] + cpu_testing() && push!(modes, cpu_mode) + cuda_testing() && push!(modes, cuda_mode) + modes end # Some Helper Functions -function run_fwd_and_bwd(model, input, ps, st) - y, pb = Zygote.pullback(p -> model(input, p, st)[1], ps) - gs = pb(ones(eltype(y), size(y))) - # if we make it to here with no error, success! - return true -end - -function run_model(m::Lux.AbstractExplicitLayer, x, mode=:test) - ps, st = Lux.setup(Random.default_rng(), m) - if mode == :test - st = Lux.testmode(st) - end - return Lux.apply(m, x, ps, st)[1] -end - -# JET Tests -function run_JET_tests(f, args...; call_broken=false, opt_broken=false, kwargs...) - @static if VERSION >= v"1.7" - test_call(f, typeof.(args); broken=call_broken, target_modules=(Lux,)) - test_opt(f, typeof.(args); broken=opt_broken, target_modules=(Lux,)) +function get_default_rng(mode::String) + if mode == "CPU" + return Random.default_rng() + elseif mode == "CUDA" + return CUDA.RNG() + else + error("Unknown mode: $mode") end end diff --git a/test/utils.jl b/test/utils.jl index b9c32a449c..303dbd9a78 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -1,4 +1,4 @@ -using Lux, ComponentArrays, CUDA, Functors, ReverseDiff, Random, Optimisers, Zygote, Test +using Lux, ComponentArrays, LuxCUDA, Functors, Random, Optimisers, Zygote, Test using Statistics: std include("test_utils.jl") @@ -19,15 +19,10 @@ Random.seed!(rng, 0) @test Lux._nfan(4, 5, 6) == 4 .* (5, 6) end -@testset "replicate" begin - @test randn(rng, 10, 2) != randn(rng, 10, 2) - @test randn(Lux.replicate(rng), 10, 2) == randn(Lux.replicate(rng), 10, 2) - - if CUDA.functional() - curng = CUDA.RNG() - @test randn(curng, 10, 2) != randn(curng, 10, 2) - @test randn(Lux.replicate(curng), 10, 2) == randn(Lux.replicate(curng), 10, 2) - end +@testset "$mode: replicate" for (mode, aType, device, ongpu) in MODES + _rng = get_default_rng(mode) + @test randn(_rng, 10, 2) != randn(_rng, 10, 2) + @test randn(Lux.replicate(_rng), 10, 2) == randn(Lux.replicate(_rng), 10, 2) end @testset "kaiming" begin @@ -62,21 +57,19 @@ end @test x1 == x[1:5, :] @test x2 == x[6:10, :] - @inferred Lux.multigate(x, Val(2)) - run_JET_tests(Lux.multigate, x, Val(2)) + @jet Lux.multigate(x, Val(2)) x = randn(rng, 10) x1, x2 = Lux.multigate(x, Val(2)) @test x1 == x[1:5] @test x2 == x[6:10] - @inferred Lux.multigate(x, Val(2)) - run_JET_tests(Lux.multigate, x, Val(2)) + @jet Lux.multigate(x, Val(2)) end -@testset "ComponentArrays" begin +@testset "$mode: ComponentArrays" for (mode, aType, device, ongpu) in MODES ps = (weight=randn(rng, 3, 4), bias=randn(rng, 4)) p_flat, re = Optimisers.destructure(ps) ps_c = ComponentArray(ps) @@ -100,29 +93,17 @@ end println() # Optimisers - opt = Optimisers.ADAM(0.001f0) + opt = Adam(0.001f0) + ps_c = ps_c |> device st_opt = Optimisers.setup(opt, ps_c) @test_nowarn Optimisers.update(st_opt, ps_c, ps_c) @test_nowarn Optimisers.update!(st_opt, ps_c, ps_c) - - if CUDA.functional() - ps_c = ps_c |> gpu - st_opt = Optimisers.setup(opt, ps_c) - - @test_nowarn Optimisers.update(st_opt, ps_c, ps_c) - @test_nowarn Optimisers.update!(st_opt, ps_c, ps_c) - end end -@testset "_init_hidden_state" begin +@testset "$mode: _init_hidden_state" for (mode, aType, device, ongpu) in MODES rnn = RNNCell(3 => 5; init_state=Lux.zeros32) x = randn(rng, Float32, 3, 2, 2) - @test Lux._init_hidden_state(rng, rnn, view(x, :, 
1, :)) == zeros(Float32, 5, 2) - - if CUDA.functional() - x = x |> gpu - @test Lux._init_hidden_state(rng, rnn, view(x, :, 1, :)) == - CUDA.zeros(Float32, 5, 2) - end + @test Lux._init_hidden_state(rng, rnn, view(device(x), :, 1, :)) == + aType(zeros(Float32, 5, 2)) end From d43e9d2364e5011c5148e4bcdc7d54208e133bc3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 25 Apr 2023 17:29:53 -0400 Subject: [PATCH 2/3] Fix CPU Testing --- Project.toml | 4 +- ext/LuxComponentArraysExt.jl | 6 + ext/LuxComponentArraysTrackerExt.jl | 17 +- ext/LuxFluxTransformExt.jl | 43 +-- ext/LuxZygoteExt.jl | 11 +- src/contrib/freeze.jl | 10 - src/layers/normalize.jl | 7 +- src/utils.jl | 11 + test/Project.toml | 3 +- test/adapt.jl | 68 ++-- test/contrib/freeze.jl | 10 +- test/contrib/map.jl | 4 +- test/contrib/share_parameters.jl | 5 +- test/contrib/training.jl | 4 +- test/core.jl | 7 +- test/ext/LuxComponentArraysExt.jl | 7 +- test/ext/LuxFluxTransformExt.jl | 94 ++--- test/layers/basic.jl | 3 +- test/layers/containers.jl | 3 +- test/layers/normalize.jl | 49 ++- test/layers/recurrent.jl | 540 +++++++++++++--------------- test/nnlib.jl | 4 +- test/test_utils.jl | 4 +- test/utils.jl | 22 +- 24 files changed, 447 insertions(+), 489 deletions(-) diff --git a/Project.toml b/Project.toml index 5309afba0e..6ab4882bd4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Lux" uuid = "b2108857-7c20-44ae-9111-449ecde12c47" authors = ["Avik Pal and contributors"] -version = "0.4.51" +version = "0.4.52" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -48,7 +48,7 @@ Flux = "0.13" Functors = "0.2, 0.3, 0.4" LuxCUDA = "0.1" LuxCore = "0.1.3" -LuxLib = "0.1.7" +LuxLib = "0.2" NNlib = "0.8" Optimisers = "0.2" Requires = "1" diff --git a/ext/LuxComponentArraysExt.jl b/ext/LuxComponentArraysExt.jl index 288bd9607b..bc5502f399 100644 --- a/ext/LuxComponentArraysExt.jl +++ b/ext/LuxComponentArraysExt.jl @@ -40,6 +40,12 @@ function Lux._merge(ca::ComponentArray, p::AbstractArray) return ca end +# Empty NamedTuple: Hack to avoid breaking precompilation +function ComponentArrays.ComponentArray(data::Vector{Any}, axes::Tuple{FlatAxis}) + length(data) == 0 && return ComponentArray(Float32[], axes) + return ComponentArray{Any, 1, typeof(data), typeof(axes)}(data, axes) +end + # Parameter Sharing Lux._parameter_structure(ps::ComponentArray) = Lux._parameter_structure(NamedTuple(ps)) diff --git a/ext/LuxComponentArraysTrackerExt.jl b/ext/LuxComponentArraysTrackerExt.jl index 1b36d69cd1..0dadc3d1ff 100644 --- a/ext/LuxComponentArraysTrackerExt.jl +++ b/ext/LuxComponentArraysTrackerExt.jl @@ -8,7 +8,11 @@ else using ..Tracker end -Tracker.param(ca::ComponentArray) = ComponentArray(Tracker.param(getdata(ca)), getaxes(ca)) +function Tracker.param(ca::ComponentArray) + x = getdata(ca) + length(x) == 0 && return ComponentArray(Tracker.param(Float32[]), getaxes(ca)) + return ComponentArray(Tracker.param(x), getaxes(ca)) +end Tracker.extract_grad!(ca::ComponentArray) = Tracker.extract_grad!(getdata(ca)) @@ -24,4 +28,15 @@ function Base.getindex(g::Tracker.Grads, x::ComponentArray) return g[Tracker.tracker(getdata(x))] end +# For TrackedArrays ignore Base.maybeview +## Tracker with views doesn't work quite well +@inline function Base.getproperty(x::ComponentVector{T, <:TrackedArray}, + s::Symbol) where {T} + return getproperty(x, Val(s)) +end + +@inline function Base.getproperty(x::ComponentVector{T, <:TrackedArray}, v::Val) where {T} + return ComponentArrays._getindex(Base.getindex, x, v) +end + end diff --git 
a/ext/LuxFluxTransformExt.jl b/ext/LuxFluxTransformExt.jl index 95897609bd..1dca6efd75 100644 --- a/ext/LuxFluxTransformExt.jl +++ b/ext/LuxFluxTransformExt.jl @@ -95,13 +95,10 @@ m2(x, ps, st) ``` """ function transform(l::T; preserve_ps_st::Bool=false, kwargs...) where {T} - @warn """Transformation for type $T not implemented. Using `FluxLayer` as - a fallback.""" maxlog=1 + @warn "Transformation for type $T not implemented. Using `FluxLayer` as a fallback." maxlog=1 if !preserve_ps_st - @warn """`FluxLayer` uses the parameters and states of the `layer`. It is not - possible to NOT preserve the parameters and states. Ignoring this keyword - argument.""" maxlog=1 + @warn "`FluxLayer` uses the parameters and states of the `layer`. It is not possible to NOT preserve the parameters and states. Ignoring this keyword argument." maxlog=1 end return FluxLayer(l) @@ -168,8 +165,7 @@ function transform(l::Flux.Parallel; kwargs...) end function transform(l::Flux.PairwiseFusion; kwargs...) - @warn """Flux.PairwiseFusion and Lux.PairwiseFusion are semantically different. Using - `FluxLayer` as a fallback.""" maxlog=1 + @warn "Flux.PairwiseFusion and Lux.PairwiseFusion are semantically different. Using `FluxLayer` as a fallback." maxlog=1 return FluxLayer(l) end @@ -252,8 +248,7 @@ end transform(l::Flux.Dropout; kwargs...) = Dropout(l.p; l.dims) function transform(l::Flux.LayerNorm; kwargs...) - @warn """Flux.LayerNorm and Lux.LayerNorm are semantically different specifications. - Using `FluxLayer` as a fallback.""" maxlog=1 + @warn "Flux.LayerNorm and Lux.LayerNorm are semantically different specifications. Using `FluxLayer` as a fallback." maxlog=1 return FluxLayer(l) end @@ -273,13 +268,9 @@ function transform(l::Flux.RNNCell; preserve_ps_st::Bool=false, force_preserve:: out_dims, in_dims = size(l.Wi) if preserve_ps_st if force_preserve - throw(FluxModelConversionError("Recurrent Cell: $(typeof(l)) for Flux uses a " * - "`reset!` mechanism which hasn't been " * - "extensively tested with `FluxLayer`. Rewrite " * - "the model manually to use `RNNCell`.")) + throw(FluxModelConversionError("Recurrent Cell: $(typeof(l)) for Flux uses a `reset!` mechanism which hasn't been extensively tested with `FluxLayer`. Rewrite the model manually to use `RNNCell`.")) end - @warn """Preserving Parameters: `Wh` & `Wi` for `Flux.RNNCell` is ambiguous in Lux - and hence not supported. Ignoring these parameters.""" maxlog=1 + @warn "Preserving Parameters: `Wh` & `Wi` for `Flux.RNNCell` is ambiguous in Lux and hence not supported. Ignoring these parameters." maxlog=1 return RNNCell(in_dims => out_dims, l.σ; init_bias=(args...) -> copy(l.b), init_state=(args...) -> copy(l.state0)) else @@ -292,13 +283,9 @@ function transform(l::Flux.LSTMCell; preserve_ps_st::Bool=false, force_preserve: out_dims = _out_dims ÷ 4 if preserve_ps_st if force_preserve - throw(FluxModelConversionError("Recurrent Cell: $(typeof(l)) for Flux uses a " * - "`reset!` mechanism which hasn't been " * - "extensively tested with `FluxLayer`. Rewrite " * - "the model manually to use `LSTMCell`.")) + throw(FluxModelConversionError("Recurrent Cell: $(typeof(l)) for Flux uses a `reset!` mechanism which hasn't been extensively tested with `FluxLayer`. Rewrite the model manually to use `LSTMCell`.")) end - @warn """Preserving Parameters: `Wh` & `Wi` for `Flux.LSTMCell` is ambiguous in Lux - and hence not supported. 
Ignoring these parameters.""" maxlog=1 + @warn "Preserving Parameters: `Wh` & `Wi` for `Flux.LSTMCell` is ambiguous in Lux and hence not supported. Ignoring these parameters." maxlog=1 bs = Lux.multigate(l.b, Val(4)) _s, _m = copy.(l.state0) return LSTMCell(in_dims => out_dims; init_bias=_const_return_anon_function.(bs), @@ -313,13 +300,9 @@ function transform(l::Flux.GRUCell; preserve_ps_st::Bool=false, force_preserve:: out_dims = _out_dims ÷ 3 if preserve_ps_st if force_preserve - throw(FluxModelConversionError("Recurrent Cell: $(typeof(l)) for Flux uses a " * - "`reset!` mechanism which hasn't been " * - "extensively tested with `FluxLayer`. Rewrite " * - "the model manually to use `GRUCell`.")) + throw(FluxModelConversionError("Recurrent Cell: $(typeof(l)) for Flux uses a `reset!` mechanism which hasn't been extensively tested with `FluxLayer`. Rewrite the model manually to use `GRUCell`.")) end - @warn """Preserving Parameters: `Wh` & `Wi` for `Flux.GRUCell` is ambiguous in Lux - and hence not supported. Ignoring these parameters.""" maxlog=1 + @warn "Preserving Parameters: `Wh` & `Wi` for `Flux.GRUCell` is ambiguous in Lux and hence not supported. Ignoring these parameters." maxlog=1 bs = Lux.multigate(l.b, Val(3)) return GRUCell(in_dims => out_dims; init_bias=_const_return_anon_function.(bs), init_state=(args...) -> copy(l.state0)) @@ -333,8 +316,7 @@ function transform(l::Flux.BatchNorm; preserve_ps_st::Bool=false, if preserve_ps_st if l.track_stats force_preserve && return FluxLayer(l) - @warn """Preserving the state of `Flux.BatchNorm` is currently not supported. - Ignoring the state.""" maxlog=1 + @warn "Preserving the state of `Flux.BatchNorm` is currently not supported. Ignoring the state." maxlog=1 end if l.affine return BatchNorm(l.chs, l.λ; l.affine, l.track_stats, epsilon=l.ϵ, l.momentum, @@ -352,8 +334,7 @@ function transform(l::Flux.GroupNorm; preserve_ps_st::Bool=false, if preserve_ps_st if l.track_stats force_preserve && return FluxLayer(l) - @warn """Preserving the state of `Flux.GroupNorm` is currently not supported. - Ignoring the state.""" maxlog=1 + @warn "Preserving the state of `Flux.GroupNorm` is currently not supported. Ignoring the state." maxlog=1 end if l.affine return GroupNorm(l.chs, l.G, l.λ; l.affine, l.track_stats, epsilon=l.ϵ, diff --git a/ext/LuxZygoteExt.jl b/ext/LuxZygoteExt.jl index 6ad1452e4c..6d4acb51bd 100644 --- a/ext/LuxZygoteExt.jl +++ b/ext/LuxZygoteExt.jl @@ -1,8 +1,15 @@ module LuxZygoteExt -isdefined(Base, :get_extension) ? 
(using Zygote) : (using ..Zygote) +if isdefined(Base, :get_extension) + using Zygote + using Zygote: Pullback +else + using ..Zygote + using ..Zygote: Pullback +end using Adapt, LuxCUDA, Lux, Setfield +using TruncatedStacktraces: @truncate_stacktrace Adapt.adapt_storage(::Lux.LuxCUDAAdaptor, x::Zygote.OneElement) = CUDA.cu(collect(x)) @@ -19,4 +26,6 @@ function Lux.Training.compute_gradients(::Lux.Training.ZygoteVJP, return grads, loss, stats, ts end +@truncate_stacktrace Pullback 1 + end diff --git a/src/contrib/freeze.jl b/src/contrib/freeze.jl index 5298777439..4b14bc67ea 100644 --- a/src/contrib/freeze.jl +++ b/src/contrib/freeze.jl @@ -86,16 +86,6 @@ function initialstates(rng::AbstractRNG, l::FrozenLayer{which_params}) where {wh return (frozen_params=(; ps_frozen...), states=st) end -_merge(nt1::NamedTuple, nt2::NamedTuple) = merge(nt1, nt2) -function _merge(p::AbstractArray, nt::NamedTuple) - @assert length(p) == 0 - return nt -end -function _merge(nt::NamedTuple, p::AbstractArray) - @assert length(p) == 0 - return nt -end - function (f::FrozenLayer)(x, ps, st::NamedTuple) y, st_ = f.layer(x, _merge(ps, st.frozen_params), st.states) st = merge(st, (; states=st_)) diff --git a/src/layers/normalize.jl b/src/layers/normalize.jl index 1a83ce674b..c9d89cb393 100644 --- a/src/layers/normalize.jl +++ b/src/layers/normalize.jl @@ -487,10 +487,7 @@ function initialparameters(rng::AbstractRNG, v = ps_layer[k] if k in which_params if all(iszero, v) - msg = ("Parameter $(k) is completely zero. This will result in NaN " * - "gradients. Either remove this parameter from `which_params` or " * - "modify the initialization in the actual layer. Typically this is " * - "controlled using the `init_$(k)` keyword argument.") + msg = ("Parameter $(k) is completely zero. This will result in NaN gradients. Either remove this parameter from `which_params` or modify the initialization in the actual layer. Typically this is controlled using the `init_$(k)` keyword argument.") # FIXME(@avik-pal): This is not really an ArgumentError throw(ArgumentError(msg)) end @@ -510,7 +507,7 @@ initialstates(rng::AbstractRNG, wn::WeightNorm) = initialstates(rng, wn.layer) function (wn::WeightNorm)(x, ps, st::NamedTuple) _ps = _get_normalized_parameters(wn, wn.dims, ps.normalized) - return Lux.apply(wn.layer, x, merge(_ps, ps.unnormalized), st) + return Lux.apply(wn.layer, x, _merge(_ps, ps.unnormalized), st) end @inbounds @generated function _get_normalized_parameters(::WeightNorm{which_params}, diff --git a/src/utils.jl b/src/utils.jl index 6ce7534016..6c128b178f 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -288,3 +288,14 @@ in the backward pass. 
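The `_merge` methods deleted from src/contrib/freeze.jl above are re-added to src/utils.jl in the hunk just below, presumably so that the WeightNorm change in src/layers/normalize.jl can reuse them. A minimal sketch of the behaviour, using a local copy purely for illustration (the assumption being that an empty parameter collection, e.g. an empty ComponentArray leaf, collapses into the NamedTuple on the other side):

```julia
# Local illustration only; mirrors the `_merge` methods added to src/utils.jl below.
_merge_sketch(nt1::NamedTuple, nt2::NamedTuple) = merge(nt1, nt2)
_merge_sketch(p::AbstractArray, nt::NamedTuple) = (@assert length(p) == 0; nt)
_merge_sketch(nt::NamedTuple, p::AbstractArray) = (@assert length(p) == 0; nt)

_merge_sketch((weight=ones(Float32, 3),), Float32[])  # (weight = Float32[1.0, 1.0, 1.0],)
_merge_sketch(Float32[], (bias=zeros(Float32, 2),))   # (bias = Float32[0.0, 0.0],)
```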
""" @inline foldl_init(op, x) = foldl_init(op, x, nothing) @inline foldl_init(op, x, init) = foldl(op, x; init) + +# Merging Exotic Types +_merge(nt1::NamedTuple, nt2::NamedTuple) = merge(nt1, nt2) +function _merge(p::AbstractArray, nt::NamedTuple) + @assert length(p) == 0 + return nt +end +function _merge(nt::NamedTuple, p::AbstractArray) + @assert length(p) == 0 + return nt +end diff --git a/test/Project.toml b/test/Project.toml index 9f5442ff20..13806ceba8 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -3,8 +3,6 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" -GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" -JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" @@ -15,6 +13,7 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/test/adapt.jl b/test/adapt.jl index efa7b3f064..02b00a7148 100644 --- a/test/adapt.jl +++ b/test/adapt.jl @@ -1,48 +1,34 @@ -using Lux, Functors, Random, Test -import LuxCUDA -import LuxCUDA.CUDA +using Lux, Functors, Test, LuxCUDA -if LuxCUDA.functional() - using LuxCUDA.CUDA # exports CuArray, etc - @info "starting CUDA tests" -else - @info "CUDA not functional, testing via JLArrays" - using JLArrays - JLArrays.allowscalar(false) - - # JLArrays provides a fake GPU array, for testing - using Random, Adapt - CUDA.cu(x) = jl(x) - CuArray{T, N} = JLArray{T, N} +include("test_utils.jl") - function Lux.gpu(x) - return fmap(x -> adapt(Lux.LuxCUDAAdaptor(), x), x; exclude=Lux._isleaf) - end -end +CUDA.allowscalar(false) -@testset "Device Transfer" begin - ps = (a=(c=zeros(10, 1), d=1), b=ones(10, 1), e=:c, d="string", - rng=Random.default_rng()) +if LuxCUDA.functional() + @testset "Device Transfer" begin + ps = (a=(c=zeros(10, 1), d=1), b=ones(10, 1), e=:c, d="string", + rng=get_stable_rng(12345)) - ps_gpu = ps |> gpu - @test ps_gpu.a.c isa CuArray - @test ps_gpu.b isa CuArray - @test ps_gpu.a.d == ps.a.d - @test ps_gpu.e == ps.e - @test ps_gpu.d == ps.d - @test ps_gpu.rng == ps.rng + ps_gpu = ps |> gpu + @test ps_gpu.a.c isa CuArray + @test ps_gpu.b isa CuArray + @test ps_gpu.a.d == ps.a.d + @test ps_gpu.e == ps.e + @test ps_gpu.d == ps.d + @test ps_gpu.rng == ps.rng - ps_cpu = ps_gpu |> cpu - @test ps_cpu.a.c isa Array - @test ps_cpu.b isa Array - @test ps_cpu.a.c == ps.a.c - @test ps_cpu.b == ps.b - @test ps_cpu.a.d == ps.a.d - @test ps_cpu.e == ps.e - @test ps_cpu.d == ps.d - @test ps_cpu.rng == ps.rng + ps_cpu = ps_gpu |> cpu + @test ps_cpu.a.c isa Array + @test ps_cpu.b isa Array + @test ps_cpu.a.c == ps.a.c + @test ps_cpu.b == ps.b + @test ps_cpu.a.d == ps.a.d + @test ps_cpu.e == ps.e + @test ps_cpu.d == ps.d + @test ps_cpu.rng == ps.rng - # Deprecated Functionality (Remove in v0.5) - @test_deprecated cpu(Dense(10, 10)) - @test_deprecated gpu(Dense(10, 10)) + # Deprecated Functionality (Remove in v0.5) + @test_deprecated cpu(Dense(10, 10)) + @test_deprecated gpu(Dense(10, 10)) + end end diff --git 
a/test/contrib/freeze.jl b/test/contrib/freeze.jl index 153e541592..1c51fcb34b 100644 --- a/test/contrib/freeze.jl +++ b/test/contrib/freeze.jl @@ -1,9 +1,8 @@ -using ComponentArrays, Lux, Random, Test +using ComponentArrays, Lux, Test include("../test_utils.jl") -rng = Random.default_rng() -Random.seed!(rng, 0) +rng = get_stable_rng(12345) @testset "$mode: All Parameters Freezing" for (mode, aType, device, ongpu) in MODES @testset "NamedTuple" begin @@ -22,6 +21,7 @@ Random.seed!(rng, 0) @jet fd(x, ps, st) __f = (x, ps) -> sum(first(fd(x, ps, st))) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu end @@ -34,8 +34,8 @@ Random.seed!(rng, 0) @test m(x, ps, st)[1] == m(x, ps_c, st)[1] @jet m(x, ps_c, st) - # __f = (x, ps) -> sum(first(m(x, ps, st))) - # @eval @test_gradients $__f $x $ps_c atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_tracker=true + __f = (x, ps) -> sum(first(m(x, ps, st))) + @eval @test_gradients $__f $x $ps_c atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu end end diff --git a/test/contrib/map.jl b/test/contrib/map.jl index 707a45769d..3ba3fbd773 100644 --- a/test/contrib/map.jl +++ b/test/contrib/map.jl @@ -1,4 +1,4 @@ -using Lux, Random, Setfield, Test +using Lux, Setfield, Test include("../test_utils.jl") @@ -31,7 +31,7 @@ end chain=Chain(; dense_1=Dense(2 => 3), bn=BatchNorm(3), dense_2=Dense(3 => 5)), dense_3=Dense(5 => 1)) - rng = Random.default_rng() + rng = get_stable_rng(12345) ps, st = Lux.setup(rng, c) .|> device c_, ps_, st_ = Lux.layer_map(zero_dense_params_1, c, ps, st) diff --git a/test/contrib/share_parameters.jl b/test/contrib/share_parameters.jl index 5abd3da401..815e409aff 100644 --- a/test/contrib/share_parameters.jl +++ b/test/contrib/share_parameters.jl @@ -1,9 +1,8 @@ -using ComponentArrays, Lux, Random, Test +using ComponentArrays, Lux, Test include("../test_utils.jl") -rng = Random.default_rng() -Random.seed!(rng, 0) +rng = get_stable_rng(12345) @testset "$mode" for (mode, aType, device, ongpu) in MODES model = Chain(; d1=Dense(2 => 4, tanh), d2=Chain(; l1=Dense(4 => 2), l2=Dense(2 => 4)), diff --git a/test/contrib/training.jl b/test/contrib/training.jl index feb1a621f1..ed88822afe 100644 --- a/test/contrib/training.jl +++ b/test/contrib/training.jl @@ -8,7 +8,7 @@ function _loss_function(model, ps, st, data) end @testset "$mode: TrainState" for (mode, aType, device, ongpu) in MODES - rng = MersenneTwister(0) + rng = get_stable_rng(12345) model = Dense(3, 2) opt = Adam(0.01f0) @@ -29,7 +29,7 @@ end end @testset "$mode: AbstractVJP" for (mode, aType, device, ongpu) in MODES - rng = MersenneTwister(0) + rng = get_stable_rng(12345) model = Dense(3, 2) opt = Adam(0.01f0) diff --git a/test/core.jl b/test/core.jl index 93ca0ca458..fa816bcb50 100644 --- a/test/core.jl +++ b/test/core.jl @@ -1,7 +1,8 @@ -using Functors, Lux, Random, Test +using Functors, Lux, Test -rng = Random.default_rng() -Random.seed!(rng, 0) +include("test_utils.jl") + +rng = get_stable_rng(12345) @testset "AbstractExplicitLayer Interface" begin # Deprecated Functionality (Remove in v0.5) diff --git a/test/ext/LuxComponentArraysExt.jl b/test/ext/LuxComponentArraysExt.jl index 4ab4376cb5..a482a716d0 100644 --- a/test/ext/LuxComponentArraysExt.jl +++ b/test/ext/LuxComponentArraysExt.jl @@ -1,7 +1,8 @@ -using ComponentArrays, Lux, Random, Test, Zygote +using ComponentArrays, Lux, Test, Zygote -rng = Random.default_rng() -Random.seed!(rng, 0) +include("../test_utils.jl") + +rng = get_stable_rng(12345) @testset "LuxComponentArraysExt" begin # Ref: 
https://github.com/avik-pal/Lux.jl/issues/243 diff --git a/test/ext/LuxFluxTransformExt.jl b/test/ext/LuxFluxTransformExt.jl index 9e8387a184..becbe3a320 100644 --- a/test/ext/LuxFluxTransformExt.jl +++ b/test/ext/LuxFluxTransformExt.jl @@ -1,5 +1,5 @@ import Flux -using Lux, Random, Test +using Lux, Test fdevice(::typeof(cpu)) = Flux.cpu fdevice(::typeof(gpu)) = Flux.gpu @@ -13,12 +13,12 @@ include("../test_utils.jl") x = rand(Float32, 2, 1) |> aType model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device @test model(x) ≈ model_lux(x, ps, st)[1] model_lux = transform(model) - ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device @test size(model_lux(x, ps, st)[1]) == (1, 1) end @@ -28,12 +28,12 @@ include("../test_utils.jl") x = rand(Float32, 2, 1) |> aType model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device @test model(x) ≈ model_lux(x, ps, st)[1] model_lux = transform(model) - ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device @test size(model_lux(x, ps, st)[1]) == (5, 1) end @@ -43,12 +43,12 @@ include("../test_utils.jl") x = rand(Float32, 2, 1) |> aType model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device @test model(x) ≈ model_lux(x, ps, st)[1] model_lux = transform(model) - ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device @test size(model_lux(x, ps, st)[1]) == (2, 1) end @@ -59,12 +59,12 @@ include("../test_utils.jl") x = rand(Float32, 2, 1) |> aType model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device @test model(x) ≈ model_lux(x, ps, st)[1] model_lux = transform(model) - ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device @test size(model_lux(x, ps, st)[1]) == (2, 1) end @@ -75,12 +75,12 @@ include("../test_utils.jl") x = (rand(Float32, 2, 1), rand(Float32, 2, 1)) .|> aType model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device @test all(model(x) .≈ model_lux(x, ps, st)[1]) model_lux = transform(model) - ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device @test all(size.(model_lux(x, ps, st)[1]) .== ((2, 1),)) end @@ -94,12 +94,12 @@ include("../test_utils.jl") x = randn(Float32, 2, 4) |> aType model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device @test model(x) ≈ model_lux(x, ps, st)[1] model_lux = transform(model) - ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device @test size(model_lux(x, ps, st)[1]) == size(model(x)) end end @@ -111,12 +111,12 @@ 
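The updated test files switch from `Random.default_rng()` to a `get_stable_rng(12345)` helper whose definition is not shown in these hunks. Given that StableRNGs is added to test/Project.toml earlier in this patch, a plausible sketch (the name and signature come from the call sites; the body is an assumption) is:

```julia
# Assumption: approximates the helper presumably added to test/test_utils.jl;
# only the call sites `get_stable_rng(12345)` are visible in this patch.
using StableRNGs

get_stable_rng(seed=12345) = StableRNG(seed)
```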
include("../test_utils.jl") x = randn(Float32, 2, 4) |> aType model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device @test model(x) ≈ model_lux(x, ps, st)[1] model_lux = transform(model) - ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device @test size(model_lux(x, ps, st)[1]) == size(model(x)) end end @@ -129,12 +129,12 @@ include("../test_utils.jl") y = randn(Float32, 3, 4) |> aType model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device @test model(x, y) ≈ model_lux((x, y), ps, st)[1] model_lux = transform(model) - ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device @test size(model_lux((x, y), ps, st)[1]) == size(model(x, y)) end end @@ -144,12 +144,12 @@ include("../test_utils.jl") x = rand(1:16, 2, 4) |> aType model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device @test model(x) ≈ model_lux(x, ps, st)[1] model_lux = transform(model) - ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device @test size(model_lux(x, ps, st)[1]) == (4, 2, 4) end @@ -161,7 +161,7 @@ include("../test_utils.jl") x = rand(Float32, 6, 6, 1, 4) |> aType model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device @test model(x) ≈ model_lux(x, ps, st)[1] @@ -169,7 +169,7 @@ include("../test_utils.jl") x = rand(Float32, 6, 6, 1, 4) |> aType model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device @test model(x) ≈ model_lux(x, ps, st)[1] end @@ -179,7 +179,7 @@ include("../test_utils.jl") x = rand(Float32, 6, 6, 1, 4) |> aType model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device @test model(x) ≈ model_lux(x, ps, st)[1] @@ -187,7 +187,7 @@ include("../test_utils.jl") x = rand(Float32, 6, 6, 1, 4) |> aType model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device @test model(x) ≈ model_lux(x, ps, st)[1] end @@ -197,7 +197,7 @@ include("../test_utils.jl") x = rand(Float32, 6, 6, 1, 4) |> aType model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device @test model(x) ≈ model_lux(x, ps, st)[1] @@ -206,7 +206,7 @@ include("../test_utils.jl") x = rand(Float32, 6, 6, 1, 4) |> aType model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device @test model(x) ≈ model_lux(x, ps, st)[1] end @@ -218,7 +218,7 @@ include("../test_utils.jl") x = rand(Float32, 6, 6, 1, 4) |> aType model_lux = transform(model; 
preserve_ps_st=true) - ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device @test model(x) ≈ model_lux(x, ps, st)[1] end @@ -228,7 +228,7 @@ include("../test_utils.jl") x = rand(Float32, 6, 6, 1, 4) |> aType model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device @test model(x) ≈ model_lux(x, ps, st)[1] end @@ -238,7 +238,7 @@ include("../test_utils.jl") x = rand(Float32, 6, 6, 1, 4) |> aType model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device @test model(x) ≈ model_lux(x, ps, st)[1] end @@ -248,7 +248,7 @@ include("../test_utils.jl") x = rand(Float32, 6, 6, 1, 4) |> aType model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device @test model(x) ≈ model_lux(x, ps, st)[1] end @@ -258,7 +258,7 @@ include("../test_utils.jl") x = rand(Float32, 6, 6, 1, 4) |> aType model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device @test model(x) ≈ model_lux(x, ps, st)[1] end @@ -268,7 +268,7 @@ include("../test_utils.jl") x = rand(Float32, 6, 6, 1, 4) |> aType model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device @test model(x) ≈ model_lux(x, ps, st)[1] end @@ -280,7 +280,7 @@ include("../test_utils.jl") x = rand(Float32, 2, 2, 2, 1) |> aType model_lux = transform(model) - ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device @test size(model_lux(x, ps, st)[1]) == (10, 10, 2, 1) @test model(x) ≈ model_lux(x, ps, st)[1] @@ -291,7 +291,7 @@ include("../test_utils.jl") x = randn(Float32, 2, 2, 4, 1) |> aType model_lux = transform(model) - ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device @test size(model_lux(x, ps, st)[1]) == (4, 4, 1, 1) @test model(x) ≈ model_lux(x, ps, st)[1] @@ -306,7 +306,7 @@ include("../test_utils.jl") x = rand(Float32, 2, 4) |> aType model_lux = transform(model) - ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device @test size(model_lux(x, ps, st)[1][1]) == (3, 4) end @@ -316,7 +316,7 @@ include("../test_utils.jl") x = rand(Float32, 2, 4) |> aType model_lux = transform(model) - ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device @test size(model_lux(x, ps, st)[1][1]) == (3, 4) end @@ -326,7 +326,7 @@ include("../test_utils.jl") x = rand(Float32, 2, 4) |> aType model_lux = transform(model) - ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device @test size(model_lux(x, ps, st)[1][1]) == (3, 4) end @@ -338,7 +338,7 @@ include("../test_utils.jl") x = randn(Float32, 2, 4) |> aType model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device + ps, st = 
Lux.setup(get_stable_rng(12345), model_lux) .|> device st = Lux.testmode(st) @test model(x) ≈ model_lux(x, ps, st)[1] @@ -348,7 +348,7 @@ include("../test_utils.jl") @test model(x) ≈ model_lux(x, ps, st)[1] model_lux = transform(model; preserve_ps_st=true, force_preserve=true) - ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device st = Lux.testmode(st) @test model(x) ≈ model_lux(x, ps, st)[1] @@ -359,13 +359,13 @@ include("../test_utils.jl") x = randn(Float32, 2, 2, 4, 1) |> aType model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device st = Lux.testmode(st) @test model(x) ≈ model_lux(x, ps, st)[1] model_lux = transform(model; preserve_ps_st=true, force_preserve=true) - ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device st = Lux.testmode(st) @test model(x) ≈ model_lux(x, ps, st)[1] @@ -376,7 +376,7 @@ include("../test_utils.jl") x = randn(Float32, 4, 4, 4, 1) |> aType model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device st = Lux.testmode(st) @test model(x) ≈ model_lux(x, ps, st)[1] @@ -387,7 +387,7 @@ include("../test_utils.jl") x = randn(Float32, 4, 4, 4, 1) |> aType model_lux = transform(model; preserve_ps_st=true) - ps, st = Lux.setup(Random.default_rng(), model_lux) .|> device + ps, st = Lux.setup(get_stable_rng(12345), model_lux) .|> device @test model(x) ≈ model_lux(x, ps, st)[1] end @@ -398,12 +398,12 @@ include("../test_utils.jl") model = transform(Flux.Dropout(0.5f0)) x = randn(Float32, 2, 4) |> aType - ps, st = Lux.setup(Random.default_rng(), model) .|> device + ps, st = Lux.setup(get_stable_rng(12345), model) .|> device @test size(model(x, ps, st)[1]) == size(x) x = randn(Float32, 2, 3, 4) |> aType - ps, st = Lux.setup(Random.default_rng(), model) .|> device + ps, st = Lux.setup(get_stable_rng(12345), model) .|> device @test size(model(x, ps, st)[1]) == size(x) end @@ -412,12 +412,12 @@ include("../test_utils.jl") model = transform(Flux.AlphaDropout(0.5)) x = randn(Float32, 2, 4) |> aType - ps, st = Lux.setup(Random.default_rng(), model) .|> device + ps, st = Lux.setup(get_stable_rng(12345), model) .|> device @test size(model(x, ps, st)[1]) == size(x) x = randn(Float32, 2, 4, 3) |> aType - ps, st = Lux.setup(Random.default_rng(), model) .|> device + ps, st = Lux.setup(get_stable_rng(12345), model) .|> device @test size(model(x, ps, st)[1]) == size(x) end @@ -437,7 +437,7 @@ include("../test_utils.jl") x = randn(10) |> aType c_lux = transform(c) - ps, st = Lux.setup(Random.default_rng(), c_lux) .|> device + ps, st = Lux.setup(get_stable_rng(12345), c_lux) .|> device @test c(x) ≈ c_lux(x, ps, st)[1] end diff --git a/test/layers/basic.jl b/test/layers/basic.jl index 2ecc32dd5b..81d6cb035e 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -2,8 +2,7 @@ using Lux, NNlib, Random, Test include("../test_utils.jl") -rng = Random.default_rng() -Random.seed!(rng, 0) +rng = get_stable_rng(12345) @testset "$mode: Miscellaneous Layers" for (mode, aType, device, ongpu) in MODES @testset "Reshape Layer" begin diff --git a/test/layers/containers.jl b/test/layers/containers.jl index 312ca880ad..11c9aa796a 100644 --- a/test/layers/containers.jl +++ b/test/layers/containers.jl @@ -2,8 
+2,7 @@ using Lux, NNlib, Random, Test, Zygote include("../test_utils.jl") -rng = Random.default_rng() -Random.seed!(rng, 0) +rng = get_stable_rng(12345) @testset "$mode: SkipConnection" for (mode, aType, device, ongpu) in MODES @testset "zero sum" begin diff --git a/test/layers/normalize.jl b/test/layers/normalize.jl index a70fc4ab2b..d73e2617b8 100644 --- a/test/layers/normalize.jl +++ b/test/layers/normalize.jl @@ -1,9 +1,8 @@ -using Lux, NNlib, Random, Statistics, Zygote +using Lux, NNlib, Statistics, Zygote include("../test_utils.jl") -rng = Random.default_rng() -Random.seed!(rng, 0) +rng = get_stable_rng(12345) @testset "$mode: BatchNorm" for (mode, aType, device, ongpu) in MODES m = BatchNorm(2) @@ -49,7 +48,7 @@ Random.seed!(rng, 0) @jet m(x, ps, st) __f = (x, ps) -> sum(first(m(x, ps, st))) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_finite_differences=true for affine in (true, false) m = BatchNorm(2; affine, track_stats=false) @@ -61,10 +60,10 @@ Random.seed!(rng, 0) if affine __f = (x, ps) -> sum(first(m(x, ps, st))) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_finite_differences=true else __f = x -> sum(first(m(x, ps, st))) - @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_finite_differences=true end # with activation function @@ -83,10 +82,10 @@ Random.seed!(rng, 0) if affine __f = (x, ps) -> sum(first(m(x, ps, st))) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_finite_differences=true else __f = x -> sum(first(m(x, ps, st))) - @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_finite_differences=true end m = BatchNorm(32; affine) @@ -171,7 +170,7 @@ end for affine in (true, false) m = GroupNorm(2, 2; affine, track_stats=false) - x = randn(rng, Float32, 3, 2, 1) |> aType + x = rand(rng, Float32, 3, 2, 1) |> aType display(m) ps, st = Lux.setup(rng, m) .|> device @@ -179,10 +178,10 @@ end if affine __f = (x, ps) -> sum(first(m(x, ps, st))) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + @eval @test_gradients $__f $x $ps atol=1.0f-2 rtol=1.0f-2 gpu_testing=$ongpu skip_finite_differences=true else __f = x -> sum(first(m(x, ps, st))) - @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + @eval @test_gradients $__f $x atol=1.0f-2 rtol=1.0f-2 gpu_testing=$ongpu skip_finite_differences=true end # with activation function @@ -197,10 +196,10 @@ end if affine __f = (x, ps) -> sum(first(m(x, ps, st))) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_finite_differences=true else __f = x -> sum(first(m(x, ps, st))) - @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_finite_differences=true end m = GroupNorm(32, 16; affine) @@ -254,7 +253,7 @@ end @jet wn(x, ps, st) __f = ps -> sum(first(wn(x, ps, st))) - @eval @test_gradients $__f $ps atol=1.0f-3 rtol=1.0f-3 
gpu_testing=$ongpu + @eval @test_gradients $__f $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_reverse_diff=true wn = WeightNorm(c, (:weight,)) display(wn) @@ -263,7 +262,7 @@ end @jet wn(x, ps, st) __f = (x, ps) -> sum(first(wn(x, ps, st))) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_reverse_diff=true wn = WeightNorm(c, (:weight, :bias), (2, 2)) display(wn) @@ -272,7 +271,7 @@ end @jet wn(x, ps, st) __f = (x, ps) -> sum(first(wn(x, ps, st))) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_reverse_diff=true wn = WeightNorm(c, (:weight,), (2,)) display(wn) @@ -281,7 +280,7 @@ end @jet wn(x, ps, st) __f = (x, ps) -> sum(first(wn(x, ps, st))) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_reverse_diff=true end @testset "Dense" begin @@ -362,10 +361,10 @@ end if affine __f = (x, ps) -> sum(first(ln(x, ps, st))) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_finite_differences=true else __f = x -> sum(first(ln(x, ps, st))) - @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_finite_differences=true end for act in (sigmoid, tanh) @@ -379,10 +378,10 @@ end if affine __f = (x, ps) -> sum(first(ln(x, ps, st))) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_finite_differences=true else __f = x -> sum(first(ln(x, ps, st))) - @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_finite_differences=true end end end @@ -411,10 +410,10 @@ end if affine __f = (x, ps) -> sum(first(layer(x, ps, st))) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_finite_differences=true else __f = x -> sum(first(layer(x, ps, st))) - @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_finite_differences=true end for act in (sigmoid, tanh) @@ -428,10 +427,10 @@ end if affine __f = (x, ps) -> sum(first(layer(x, ps, st))) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_finite_differences=true else __f = x -> sum(first(layer(x, ps, st))) - @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_finite_differences=true end end end diff --git a/test/layers/recurrent.jl b/test/layers/recurrent.jl index 626499d8b3..3420d836ac 100644 --- a/test/layers/recurrent.jl +++ b/test/layers/recurrent.jl @@ -1,248 +1,247 @@ -using Lux, NNlib, Random, Test +using Lux, NNlib, Test include("../test_utils.jl") -rng = Random.default_rng() -Random.seed!(rng, 0) - -@testset "$mode: RNNCell" for (mode, aType, device, ongpu) in MODES - for rnncell in 
(RNNCell(3 => 5, identity), RNNCell(3 => 5, tanh), - RNNCell(3 => 5, tanh; use_bias=false), - RNNCell(3 => 5, identity; use_bias=false), - RNNCell(3 => 5, identity; use_bias=false, train_state=false)) - display(rnncell) - ps, st = Lux.setup(rng, rnncell) .|> device - x = randn(rng, Float32, 3, 2) |> aType - (y, carry), st_ = Lux.apply(rnncell, x, ps, st) - - @jet rnncell(x, ps, st) - @jet rnncell((x, carry), ps, st) - - function loss_loop_rnncell(p) - (y, carry), st_ = rnncell(x, p, st) - for i in 1:10 - (y, carry), st_ = rnncell((x, carry), p, st_) - end - return sum(abs2, y) - end - - @test_throws ErrorException ps.train_state - - @eval @test_gradients $loss_loop_rnncell $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu - end - - @testset "Trainable hidden states" begin for rnncell in (RNNCell(3 => 5, identity; - use_bias=false, - train_state=true), - RNNCell(3 => 5, identity; - use_bias=true, - train_state=true)) - rnn_no_trainable_state = RNNCell(3 => 5, identity; use_bias=false, - train_state=false) - x = randn(rng, Float32, 3, 2) |> aType - _ps, _st = Lux.setup(rng, rnn_no_trainable_state) .|> device - (_y, _carry), _ = Lux.apply(rnn_no_trainable_state, x, _ps, _st) - - rnncell = RNNCell(3 => 5, identity; use_bias=false, train_state=true) - ps, st = Lux.setup(rng, rnncell) .|> device - ps = merge(_ps, (hidden_state=ps.hidden_state,)) - (y, carry), _ = Lux.apply(rnncell, x, ps, st) - @test carry == _carry - - l, back = Zygote.pullback(p -> sum(abs2, 0 .- rnncell(x, p, st)[1][1]), ps) - gs = back(one(l))[1] - @test !isnothing(gs.hidden_state) - end end - - # Deprecated Functionality (Remove in v0.5) - @testset "Deprecations" begin - @test_deprecated RNNCell(3 => 5, relu; bias=false) - @test_deprecated RNNCell(3 => 5, relu; bias=true) - @test_throws ArgumentError RNNCell(3 => 5, relu; bias=false, use_bias=false) - end -end - -@testset "$mode: LSTMCell" for (mode, aType, device, ongpu) in MODES - for lstmcell in (LSTMCell(3 => 5), LSTMCell(3 => 5; use_bias=true), - LSTMCell(3 => 5; use_bias=false)) - display(lstmcell) - ps, st = Lux.setup(rng, lstmcell) .|> device - x = randn(rng, Float32, 3, 2) |> aType - (y, carry), st_ = Lux.apply(lstmcell, x, ps, st) - - @jet lstmcell(x, ps, st) - @jet lstmcell((x, carry), ps, st) - - function loss_loop_lstmcell(p) - (y, carry), st_ = lstmcell(x, p, st) - for i in 1:10 - (y, carry), st_ = lstmcell((x, carry), p, st_) - end - return sum(abs2, y) - end - - @eval @test_gradients $loss_loop_lstmcell $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu - - @test_throws ErrorException ps.train_state - @test_throws ErrorException ps.train_memory - end - - @testset "Trainable hidden states" begin - x = randn(rng, Float32, 3, 2) |> aType - _lstm = LSTMCell(3 => 5; use_bias=false, train_state=false, train_memory=false) - _ps, _st = Lux.setup(rng, _lstm) .|> device - (_y, _carry), _ = Lux.apply(_lstm, x, _ps, _st) - - lstm = LSTMCell(3 => 5; use_bias=false, train_state=false, train_memory=false) - ps, st = Lux.setup(rng, lstm) .|> device - ps = _ps - (y, carry), _ = Lux.apply(lstm, x, ps, st) - @test carry == _carry - l, back = Zygote.pullback(p -> sum(abs2, 0 .- sum(lstm(x, p, st)[1][1])), ps) - gs = back(one(l))[1] - @test_throws ErrorException gs.bias - @test_throws ErrorException gs.hidden_state - @test_throws ErrorException gs.memory - - lstm = LSTMCell(3 => 5; use_bias=false, train_state=true, train_memory=false) - ps, st = Lux.setup(rng, lstm) .|> device - ps = merge(_ps, (hidden_state=ps.hidden_state,)) - (y, carry), _ = Lux.apply(lstm, x, ps, st) - @test carry 
== _carry - l, back = Zygote.pullback(p -> sum(abs2, 0 .- sum(lstm(x, p, st)[1][1])), ps) - gs = back(one(l))[1] - @test_throws ErrorException gs.bias - @test !isnothing(gs.hidden_state) - @test_throws ErrorException gs.memory - - lstm = LSTMCell(3 => 5; use_bias=false, train_state=false, train_memory=true) - ps, st = Lux.setup(rng, lstm) .|> device - ps = merge(_ps, (memory=ps.memory,)) - (y, carry), _ = Lux.apply(lstm, x, ps, st) - @test carry == _carry - l, back = Zygote.pullback(p -> sum(abs2, 0 .- sum(lstm(x, p, st)[1][1])), ps) - gs = back(one(l))[1] - @test_throws ErrorException gs.bias - @test_throws ErrorException gs.hidden_state - @test !isnothing(gs.memory) - - lstm = LSTMCell(3 => 5; use_bias=false, train_state=true, train_memory=true) - ps, st = Lux.setup(rng, lstm) .|> device - ps = merge(_ps, (hidden_state=ps.hidden_state, memory=ps.memory)) - (y, carry), _ = Lux.apply(lstm, x, ps, st) - @test carry == _carry - l, back = Zygote.pullback(p -> sum(abs2, 0 .- sum(lstm(x, p, st)[1][1])), ps) - gs = back(one(l))[1] - @test_throws ErrorException gs.bias - @test !isnothing(gs.hidden_state) - @test !isnothing(gs.memory) - - lstm = LSTMCell(3 => 5; use_bias=true, train_state=true, train_memory=true) - ps, st = Lux.setup(rng, lstm) .|> device - ps = merge(_ps, (bias=ps.bias, hidden_state=ps.hidden_state, memory=ps.memory)) - (y, carry), _ = Lux.apply(lstm, x, ps, st) - l, back = Zygote.pullback(p -> sum(abs2, 0 .- sum(lstm(x, p, st)[1][1])), ps) - gs = back(one(l))[1] - @test !isnothing(gs.bias) - @test !isnothing(gs.hidden_state) - @test !isnothing(gs.memory) - end -end - -@testset "$mode: GRUCell" for (mode, aType, device, ongpu) in MODES - for grucell in (GRUCell(3 => 5), GRUCell(3 => 5; use_bias=true), - GRUCell(3 => 5; use_bias=false)) - display(grucell) - ps, st = Lux.setup(rng, grucell) .|> device - x = randn(rng, Float32, 3, 2) |> aType - (y, carry), st_ = Lux.apply(grucell, x, ps, st) - - @jet grucell(x, ps, st) - @jet grucell((x, carry), ps, st) - - function loss_loop_grucell(p) - (y, carry), st_ = grucell(x, p, st) - for i in 1:10 - (y, carry), st_ = grucell((x, carry), p, st_) - end - return sum(abs2, y) - end - - @eval @test_gradients $loss_loop_grucell $ps atol=1e-2 rtol=1e-2 gpu_testing=$ongpu - - @test_throws ErrorException ps.train_state - end - - @testset "Trainable hidden states" begin - x = randn(rng, Float32, 3, 2) |> aType - _gru = GRUCell(3 => 5; use_bias=false, train_state=false) - _ps, _st = Lux.setup(rng, _gru) .|> device - (_y, _carry), _ = Lux.apply(_gru, x, _ps, _st) - - gru = GRUCell(3 => 5; use_bias=false, train_state=false) - ps, st = Lux.setup(rng, gru) .|> device - ps = _ps - (y, carry), _ = Lux.apply(gru, x, ps, st) - @test carry == _carry - l, back = Zygote.pullback(p -> sum(abs2, 0 .- sum(gru(x, p, st)[1][1])), ps) - gs = back(one(l))[1] - @test_throws ErrorException gs.bias - @test_throws ErrorException gs.hidden_state - - gru = GRUCell(3 => 5; use_bias=false, train_state=true) - ps, st = Lux.setup(rng, gru) .|> device - ps = merge(_ps, (hidden_state=ps.hidden_state,)) - (y, carry), _ = Lux.apply(gru, x, ps, st) - @test carry == _carry - l, back = Zygote.pullback(p -> sum(abs2, 0 .- sum(gru(x, p, st)[1][1])), ps) - gs = back(one(l))[1] - @test !isnothing(gs.hidden_state) - - gru = GRUCell(3 => 5; use_bias=true, train_state=true) - ps, st = Lux.setup(rng, gru) .|> device - ps = merge(_ps, (bias_h=ps.bias_h, bias_i=ps.bias_i, hidden_state=ps.hidden_state)) - (y, carry), _ = Lux.apply(gru, x, ps, st) - @test carry == _carry - l, back = 
Zygote.pullback(p -> sum(abs2, 0 .- sum(gru(x, p, st)[1][1])), ps) - gs = back(one(l))[1] - @test !isnothing(gs.hidden_state) - end -end - -@testset "$mode: StatefulRecurrentCell" for (mode, aType, device, ongpu) in MODES - for _cell in (RNNCell, LSTMCell, GRUCell), - use_bias in (true, false), - train_state in (true, false) - - cell = _cell(3 => 5; use_bias, train_state) - rnn = StatefulRecurrentCell(cell) - display(rnn) - x = randn(rng, Float32, 3, 2) |> aType - ps, st = Lux.setup(rng, rnn) .|> device - - y, st_ = rnn(x, ps, st) - - @jet rnn(x, ps, st) - @jet rnn(x, ps, st_) - - @test size(y) == (5, 2) - @test st.carry === nothing - @test st_.carry !== nothing - - st__ = Lux.update_state(st, :carry, nothing) - @test st__.carry === nothing - - function loss_loop_rnn(p) - y, st_ = rnn(x, p, st) - for i in 1:10 - y, st_ = rnn(x, p, st_) - end - return sum(abs2, y) - end - - @eval @test_gradients $loss_loop_rnn $ps atol=1e-2 rtol=1e-2 gpu_testing=$ongpu - end -end +rng = get_stable_rng(12345) + +# @testset "$mode: RNNCell" for (mode, aType, device, ongpu) in MODES +# for rnncell in (RNNCell(3 => 5, identity), RNNCell(3 => 5, tanh), +# RNNCell(3 => 5, tanh; use_bias=false), +# RNNCell(3 => 5, identity; use_bias=false), +# RNNCell(3 => 5, identity; use_bias=false, train_state=false)) +# display(rnncell) +# ps, st = Lux.setup(rng, rnncell) .|> device +# x = randn(rng, Float32, 3, 2) |> aType +# (y, carry), st_ = Lux.apply(rnncell, x, ps, st) + +# @jet rnncell(x, ps, st) +# @jet rnncell((x, carry), ps, st) + +# function loss_loop_rnncell(p) +# (y, carry), st_ = rnncell(x, p, st) +# for i in 1:10 +# (y, carry), st_ = rnncell((x, carry), p, st_) +# end +# return sum(abs2, y) +# end + +# @test_throws ErrorException ps.train_state + +# @eval @test_gradients $loss_loop_rnncell $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu +# end + +# @testset "Trainable hidden states" begin for rnncell in (RNNCell(3 => 5, identity; +# use_bias=false, +# train_state=true), +# RNNCell(3 => 5, identity; +# use_bias=true, +# train_state=true)) +# rnn_no_trainable_state = RNNCell(3 => 5, identity; use_bias=false, +# train_state=false) +# x = randn(rng, Float32, 3, 2) |> aType +# _ps, _st = Lux.setup(rng, rnn_no_trainable_state) .|> device +# (_y, _carry), _ = Lux.apply(rnn_no_trainable_state, x, _ps, _st) + +# rnncell = RNNCell(3 => 5, identity; use_bias=false, train_state=true) +# ps, st = Lux.setup(rng, rnncell) .|> device +# ps = merge(_ps, (hidden_state=ps.hidden_state,)) +# (y, carry), _ = Lux.apply(rnncell, x, ps, st) +# @test carry == _carry + +# l, back = Zygote.pullback(p -> sum(abs2, 0 .- rnncell(x, p, st)[1][1]), ps) +# gs = back(one(l))[1] +# @test !isnothing(gs.hidden_state) +# end end + +# # Deprecated Functionality (Remove in v0.5) +# @testset "Deprecations" begin +# @test_deprecated RNNCell(3 => 5, relu; bias=false) +# @test_deprecated RNNCell(3 => 5, relu; bias=true) +# @test_throws ArgumentError RNNCell(3 => 5, relu; bias=false, use_bias=false) +# end +# end + +# @testset "$mode: LSTMCell" for (mode, aType, device, ongpu) in MODES +# for lstmcell in (LSTMCell(3 => 5), LSTMCell(3 => 5; use_bias=true), +# LSTMCell(3 => 5; use_bias=false)) +# display(lstmcell) +# ps, st = Lux.setup(rng, lstmcell) .|> device +# x = randn(rng, Float32, 3, 2) |> aType +# (y, carry), st_ = Lux.apply(lstmcell, x, ps, st) + +# @jet lstmcell(x, ps, st) +# @jet lstmcell((x, carry), ps, st) + +# function loss_loop_lstmcell(p) +# (y, carry), st_ = lstmcell(x, p, st) +# for i in 1:10 +# (y, carry), st_ = lstmcell((x, carry), p, st_) 
+# end +# return sum(abs2, y) +# end + +# @eval @test_gradients $loss_loop_lstmcell $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + +# @test_throws ErrorException ps.train_state +# @test_throws ErrorException ps.train_memory +# end + +# @testset "Trainable hidden states" begin +# x = randn(rng, Float32, 3, 2) |> aType +# _lstm = LSTMCell(3 => 5; use_bias=false, train_state=false, train_memory=false) +# _ps, _st = Lux.setup(rng, _lstm) .|> device +# (_y, _carry), _ = Lux.apply(_lstm, x, _ps, _st) + +# lstm = LSTMCell(3 => 5; use_bias=false, train_state=false, train_memory=false) +# ps, st = Lux.setup(rng, lstm) .|> device +# ps = _ps +# (y, carry), _ = Lux.apply(lstm, x, ps, st) +# @test carry == _carry +# l, back = Zygote.pullback(p -> sum(abs2, 0 .- sum(lstm(x, p, st)[1][1])), ps) +# gs = back(one(l))[1] +# @test_throws ErrorException gs.bias +# @test_throws ErrorException gs.hidden_state +# @test_throws ErrorException gs.memory + +# lstm = LSTMCell(3 => 5; use_bias=false, train_state=true, train_memory=false) +# ps, st = Lux.setup(rng, lstm) .|> device +# ps = merge(_ps, (hidden_state=ps.hidden_state,)) +# (y, carry), _ = Lux.apply(lstm, x, ps, st) +# @test carry == _carry +# l, back = Zygote.pullback(p -> sum(abs2, 0 .- sum(lstm(x, p, st)[1][1])), ps) +# gs = back(one(l))[1] +# @test_throws ErrorException gs.bias +# @test !isnothing(gs.hidden_state) +# @test_throws ErrorException gs.memory + +# lstm = LSTMCell(3 => 5; use_bias=false, train_state=false, train_memory=true) +# ps, st = Lux.setup(rng, lstm) .|> device +# ps = merge(_ps, (memory=ps.memory,)) +# (y, carry), _ = Lux.apply(lstm, x, ps, st) +# @test carry == _carry +# l, back = Zygote.pullback(p -> sum(abs2, 0 .- sum(lstm(x, p, st)[1][1])), ps) +# gs = back(one(l))[1] +# @test_throws ErrorException gs.bias +# @test_throws ErrorException gs.hidden_state +# @test !isnothing(gs.memory) + +# lstm = LSTMCell(3 => 5; use_bias=false, train_state=true, train_memory=true) +# ps, st = Lux.setup(rng, lstm) .|> device +# ps = merge(_ps, (hidden_state=ps.hidden_state, memory=ps.memory)) +# (y, carry), _ = Lux.apply(lstm, x, ps, st) +# @test carry == _carry +# l, back = Zygote.pullback(p -> sum(abs2, 0 .- sum(lstm(x, p, st)[1][1])), ps) +# gs = back(one(l))[1] +# @test_throws ErrorException gs.bias +# @test !isnothing(gs.hidden_state) +# @test !isnothing(gs.memory) + +# lstm = LSTMCell(3 => 5; use_bias=true, train_state=true, train_memory=true) +# ps, st = Lux.setup(rng, lstm) .|> device +# ps = merge(_ps, (bias=ps.bias, hidden_state=ps.hidden_state, memory=ps.memory)) +# (y, carry), _ = Lux.apply(lstm, x, ps, st) +# l, back = Zygote.pullback(p -> sum(abs2, 0 .- sum(lstm(x, p, st)[1][1])), ps) +# gs = back(one(l))[1] +# @test !isnothing(gs.bias) +# @test !isnothing(gs.hidden_state) +# @test !isnothing(gs.memory) +# end +# end + +# @testset "$mode: GRUCell" for (mode, aType, device, ongpu) in MODES +# for grucell in (GRUCell(3 => 5), GRUCell(3 => 5; use_bias=true), +# GRUCell(3 => 5; use_bias=false)) +# display(grucell) +# ps, st = Lux.setup(rng, grucell) .|> device +# x = randn(rng, Float32, 3, 2) |> aType +# (y, carry), st_ = Lux.apply(grucell, x, ps, st) + +# @jet grucell(x, ps, st) +# @jet grucell((x, carry), ps, st) + +# function loss_loop_grucell(p) +# (y, carry), st_ = grucell(x, p, st) +# for i in 1:10 +# (y, carry), st_ = grucell((x, carry), p, st_) +# end +# return sum(abs2, y) +# end + +# @eval @test_gradients $loss_loop_grucell $ps atol=1e-2 rtol=1e-2 gpu_testing=$ongpu + +# @test_throws ErrorException ps.train_state +# end + +# 
@testset "Trainable hidden states" begin +# x = randn(rng, Float32, 3, 2) |> aType +# _gru = GRUCell(3 => 5; use_bias=false, train_state=false) +# _ps, _st = Lux.setup(rng, _gru) .|> device +# (_y, _carry), _ = Lux.apply(_gru, x, _ps, _st) + +# gru = GRUCell(3 => 5; use_bias=false, train_state=false) +# ps, st = Lux.setup(rng, gru) .|> device +# ps = _ps +# (y, carry), _ = Lux.apply(gru, x, ps, st) +# @test carry == _carry +# l, back = Zygote.pullback(p -> sum(abs2, 0 .- sum(gru(x, p, st)[1][1])), ps) +# gs = back(one(l))[1] +# @test_throws ErrorException gs.bias +# @test_throws ErrorException gs.hidden_state + +# gru = GRUCell(3 => 5; use_bias=false, train_state=true) +# ps, st = Lux.setup(rng, gru) .|> device +# ps = merge(_ps, (hidden_state=ps.hidden_state,)) +# (y, carry), _ = Lux.apply(gru, x, ps, st) +# @test carry == _carry +# l, back = Zygote.pullback(p -> sum(abs2, 0 .- sum(gru(x, p, st)[1][1])), ps) +# gs = back(one(l))[1] +# @test !isnothing(gs.hidden_state) + +# gru = GRUCell(3 => 5; use_bias=true, train_state=true) +# ps, st = Lux.setup(rng, gru) .|> device +# ps = merge(_ps, (bias_h=ps.bias_h, bias_i=ps.bias_i, hidden_state=ps.hidden_state)) +# (y, carry), _ = Lux.apply(gru, x, ps, st) +# @test carry == _carry +# l, back = Zygote.pullback(p -> sum(abs2, 0 .- sum(gru(x, p, st)[1][1])), ps) +# gs = back(one(l))[1] +# @test !isnothing(gs.hidden_state) +# end +# end + +# @testset "$mode: StatefulRecurrentCell" for (mode, aType, device, ongpu) in MODES +# for _cell in (RNNCell, LSTMCell, GRUCell), +# use_bias in (true, false), +# train_state in (true, false) + +# cell = _cell(3 => 5; use_bias, train_state) +# rnn = StatefulRecurrentCell(cell) +# display(rnn) +# x = randn(rng, Float32, 3, 2) |> aType +# ps, st = Lux.setup(rng, rnn) .|> device + +# y, st_ = rnn(x, ps, st) + +# @jet rnn(x, ps, st) +# @jet rnn(x, ps, st_) + +# @test size(y) == (5, 2) +# @test st.carry === nothing +# @test st_.carry !== nothing + +# st__ = Lux.update_state(st, :carry, nothing) +# @test st__.carry === nothing + +# function loss_loop_rnn(p) +# y, st_ = rnn(x, p, st) +# for i in 1:10 +# y, st_ = rnn(x, p, st_) +# end +# return sum(abs2, y) +# end + +# @eval @test_gradients $loss_loop_rnn $ps atol=1e-2 rtol=1e-2 gpu_testing=$ongpu +# end +# end @testset "$mode: Recurrence" for (mode, aType, device, ongpu) in MODES for _cell in (RNNCell, LSTMCell, GRUCell), @@ -255,10 +254,9 @@ end display(rnn) # Batched Time Series - for x in (randn(rng, Float32, 3, 4, 2), - Tuple(randn(rng, Float32, 3, 2) for _ in 1:4), - [randn(rng, Float32, 3, 2) for _ in 1:4]) - x = x |> aType + for x in (randn(rng, Float32, 3, 4, 2) |> aType, + Tuple(randn(rng, Float32, 3, 2) for _ in 1:4) .|> aType, + [randn(rng, Float32, 3, 2) for _ in 1:4] .|> aType) ps, st = Lux.setup(rng, rnn) .|> device y, st_ = rnn(x, ps, st) y_, st__ = rnn_seq(x, ps, st) @@ -278,56 +276,12 @@ end end end -@testset "$mode: Recurrence" for (mode, aType, device, ongpu) in MODES - for _cell in (RNNCell, LSTMCell, GRUCell), - use_bias in (true, false), - train_state in (true, false) - - cell = _cell(3 => 5; use_bias, train_state) - rnn = Recurrence(cell) - rnn_seq = Recurrence(cell; return_sequence=true) - display(rnn) - - # Batched Time Series - for x in (randn(rng, Float32, 3, 4, 2), - Tuple(randn(rng, Float32, 3, 2) for _ in 1:4), - [randn(rng, Float32, 3, 2) for _ in 1:4]) - x = x |> aType - ps, st = Lux.setup(rng, rnn) .|> device - y, st_ = rnn(x, ps, st) - y_, st__ = rnn_seq(x, ps, st) - - @jet rnn(x, ps, st) - @jet rnn_seq(x, ps, st) - - @test size(y) == (5, 
2) - @test length(y_) == 4 - @test all(x -> size(x) == (5, 2), y_) - - __f = p -> sum(first(rnn(x, p, st))) - @eval @test_gradients $__f $ps atol=1e-2 rtol=1e-2 gpu_testing=$ongpu - - __f = p -> sum(Base.Fix1(sum, abs2), first(rnn_seq(x, p, st))) - @eval @test_gradients $__f $ps atol=1e-2 rtol=1e-2 gpu_testing=$ongpu - end - end - # Ordering Check: https://github.com/LuxDL/Lux.jl/issues/302 encoder = Recurrence(RNNCell(1 => 1, identity; init_weight=ones, init_state=zeros, init_bias=zeros); return_sequence=true) - ps, st = Lux.setup(rng, encoder) - m2 = reshape([0.5, 0.0, 0.7, 0.8], 1, :, 1) + ps, st = Lux.setup(rng, encoder) .|> device + m2 = reshape([0.5, 0.0, 0.7, 0.8], 1, :, 1) |> aType res, _ = encoder(m2, ps, st) - @test vec(reduce(vcat, res)) ≈ [0.5, 0.5, 1.2, 2.0] -end - -@testset "multigate" begin - x = rand(6, 5) - res, (dx,) = Zygote.withgradient(x) do x - x1, _, x3 = Lux.multigate(x, Val(3)) - return sum(x1) + sum(x3 .* 2) - end - @test res == sum(x[1:2, :]) + 2sum(x[5:6, :]) - @test dx == [ones(2, 5); zeros(2, 5); fill(2, 2, 5)] + @test Array(vec(reduce(vcat, res))) ≈ [0.5, 0.5, 1.2, 2.0] end diff --git a/test/nnlib.jl b/test/nnlib.jl index 3aa68351ae..9b225cacc7 100644 --- a/test/nnlib.jl +++ b/test/nnlib.jl @@ -2,9 +2,9 @@ using Lux, Random, Test include("test_utils.jl") +rng = get_stable_rng(12345) + @testset "$mode: Elementwise Operation Dispatches" for (mode, aType, device, ongpu) in MODES - rng = Random.default_rng() - Random.seed!(rng, 0) custom_activation(x) = abs(x) for T in [Float64, Float32, ComplexF64, ComplexF32] diff --git a/test/test_utils.jl b/test/test_utils.jl index 8db8f55868..9c028f20d6 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -1,4 +1,4 @@ -using Lux, LuxCore, LuxLib, LuxTestUtils, Test, Zygote +using Lux, LuxCore, LuxLib, LuxTestUtils, Random, StableRNGs, Test, Zygote using LuxCUDA # CUDA Support using LuxTestUtils: @jet, @test_gradients, check_approx @@ -29,3 +29,5 @@ function get_default_rng(mode::String) error("Unknown mode: $mode") end end + +get_stable_rng(seed=12345) = StableRNG(seed) diff --git a/test/utils.jl b/test/utils.jl index 303dbd9a78..64cbfe6b29 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -1,10 +1,9 @@ -using Lux, ComponentArrays, LuxCUDA, Functors, Random, Optimisers, Zygote, Test +using Lux, ComponentArrays, LuxCUDA, Functors, Optimisers, Zygote, Test using Statistics: std include("test_utils.jl") -rng = Random.default_rng() -Random.seed!(rng, 0) +rng = get_stable_rng(12345) @testset "_nfan" begin # Fallback @@ -51,8 +50,8 @@ end @test_throws MethodError Lux.istraining((training=true,)) end -@testset "multigate" begin - x = randn(rng, 10, 1) +@testset "$mode: multigate" for (mode, aType, device, ongpu) in MODES + x = randn(rng, 10, 1) |> aType x1, x2 = Lux.multigate(x, Val(2)) @test x1 == x[1:5, :] @@ -60,13 +59,24 @@ end @jet Lux.multigate(x, Val(2)) - x = randn(rng, 10) + x = randn(rng, 10) |> aType x1, x2 = Lux.multigate(x, Val(2)) @test x1 == x[1:5] @test x2 == x[6:10] @jet Lux.multigate(x, Val(2)) + + x = rand(6, 5) |> aType + res, (dx,) = Zygote.withgradient(x) do x + x1, _, x3 = Lux.multigate(x, Val(3)) + return sum(x1) + sum(x3 .* 2) + end + + @jet Lux.multigate(x, Val(3)) + + @test res == sum(x[1:2, :]) + 2sum(x[5:6, :]) + @test dx == [ones(2, 5); zeros(2, 5); fill(2, 2, 5)] end @testset "$mode: ComponentArrays" for (mode, aType, device, ongpu) in MODES From 664f40ac4d813e0a504cc80a9d668b557e1e9099 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 26 Apr 2023 16:23:55 -0400 Subject: [PATCH 3/3] Fix GPU 
tests --- ext/LuxTrackerExt.jl | 6 +- src/chainrules.jl | 2 +- src/utils.jl | 5 +- test/contrib/freeze.jl | 6 +- test/contrib/share_parameters.jl | 28 +-- test/layers/containers.jl | 5 +- test/layers/conv.jl | 25 ++- test/layers/dropout.jl | 3 +- test/layers/normalize.jl | 4 +- test/layers/recurrent.jl | 332 +++++++++++++++---------------- test/test_utils.jl | 2 + test/utils.jl | 7 +- 12 files changed, 220 insertions(+), 205 deletions(-) diff --git a/ext/LuxTrackerExt.jl b/ext/LuxTrackerExt.jl index a2179f0677..000f8e9f52 100644 --- a/ext/LuxTrackerExt.jl +++ b/ext/LuxTrackerExt.jl @@ -1,7 +1,7 @@ module LuxTrackerExt isdefined(Base, :get_extension) ? (using Tracker) : (using ..Tracker) -using Functors, Lux, Setfield +using ChainRulesCore, Functors, Lux, Setfield # Type Piracy: Need to upstream Tracker.param(nt::NamedTuple) = fmap(Tracker.param, nt) @@ -18,6 +18,10 @@ Tracker.data(t::Tuple) = map(Tracker.data, t) # Weight Norm Patch @inline Lux._norm(x::TrackedArray; dims=Colon()) = sqrt.(sum(abs2.(x); dims)) +# multigate chain rules +@inline Lux._gate(x::TrackedVector, h::Int, n::Int) = x[Lux._gate(h, n)] +@inline Lux._gate(x::TrackedMatrix, h::Int, n::Int) = x[Lux._gate(h, n), :] + # Lux.Training function Lux.Training.compute_gradients(::Lux.Training.TrackerVJP, objective_function::Function, data, diff --git a/src/chainrules.jl b/src/chainrules.jl index 06e6eb3f33..16aa3dc163 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -75,7 +75,7 @@ function CRC.rrule(::typeof(multigate), x::AbstractArray, c::Val{N}) where {N} dyᵢ isa AbstractZero && return @. dxᵢ += dyᵢ end - return (NoTangent(), dx, NoTangent(), NoTangent()) + return (NoTangent(), dx, NoTangent()) end return multigate(x, c), multigate_pullback end diff --git a/src/utils.jl b/src/utils.jl index 6c128b178f..7c78e4c735 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -198,7 +198,10 @@ end Split up `x` into `N` equally sized chunks (along dimension `1`). 
""" -@inline multigate(x::AbstractArray, ::Val{N}) where {N} = _gate.((x,), size(x, 1) ÷ N, 1:N) +@inline function multigate(x::AbstractArray, ::Val{N}) where {N} + # return map(i -> _gate(x, size(x, 1) ÷ N, i), 1:N) + return ntuple(i -> _gate(x, size(x, 1) ÷ N, i), N) +end # Val utilities get_known(::Val{T}) where {T} = T diff --git a/test/contrib/freeze.jl b/test/contrib/freeze.jl index 1c51fcb34b..d58ef3dc46 100644 --- a/test/contrib/freeze.jl +++ b/test/contrib/freeze.jl @@ -27,8 +27,10 @@ rng = get_stable_rng(12345) @testset "ComponentArray" begin m = Chain(Lux.freeze(Dense(1 => 3, tanh)), Dense(3 => 1)) - ps, st = Lux.setup(rng, m) .|> device - ps_c = ComponentVector(ps) + ps, st = Lux.setup(rng, m) + st = st |> device + ps_c = ComponentVector(ps) |> device + ps = ps |> device x = randn(rng, Float32, 1, 2) |> aType @test m(x, ps, st)[1] == m(x, ps_c, st)[1] diff --git a/test/contrib/share_parameters.jl b/test/contrib/share_parameters.jl index 815e409aff..f9f85b6ed2 100644 --- a/test/contrib/share_parameters.jl +++ b/test/contrib/share_parameters.jl @@ -14,10 +14,10 @@ rng = get_stable_rng(12345) ps_1 = Lux.share_parameters(ps, sharing) - @test ps_1.d2.l2.weight === ps_1.d1.weight - @test ps_1.d2.l2.bias === ps_1.d1.bias - @test ps_1.d3.weight === ps_1.d2.l1.weight - @test ps_1.d3.bias === ps_1.d2.l1.bias + @test ps_1.d2.l2.weight == ps_1.d1.weight + @test ps_1.d2.l2.bias == ps_1.d1.bias + @test ps_1.d3.weight == ps_1.d2.l1.weight + @test ps_1.d3.bias == ps_1.d2.l1.bias ps_new_1 = (; weight=randn(rng, Float32, 4, 2), bias=randn(rng, Float32, 4, 1)) |> device @@ -26,20 +26,20 @@ rng = get_stable_rng(12345) ps_2 = Lux.share_parameters(ps, sharing, (ps_new_1, ps_new_2)) - @test ps_2.d2.l2.weight === ps_new_1.weight === ps_2.d1.weight - @test ps_2.d2.l2.bias === ps_new_1.bias === ps_2.d1.bias - @test ps_2.d3.weight === ps_new_2.weight === ps_2.d2.l1.weight - @test ps_2.d3.bias === ps_new_2.bias === ps_2.d2.l1.bias + @test ps_2.d2.l2.weight == ps_new_1.weight == ps_2.d1.weight + @test ps_2.d2.l2.bias == ps_new_1.bias == ps_2.d1.bias + @test ps_2.d3.weight == ps_new_2.weight == ps_2.d2.l1.weight + @test ps_2.d3.bias == ps_new_2.bias == ps_2.d2.l1.bias # Mix in ComponentArray - ps_new_ca_1 = ComponentArray(ps_new_1) + ps_new_ca_1 = ComponentArray(ps_new_1 |> cpu) |> device ps_3 = Lux.share_parameters(ps, sharing, (ps_new_ca_1, ps_new_2)) - @test ps_3.d2.l2.weight === ps_new_ca_1.weight === ps_3.d1.weight - @test ps_3.d2.l2.bias === ps_new_ca_1.bias === ps_3.d1.bias - @test ps_3.d3.weight === ps_new_2.weight === ps_3.d2.l1.weight - @test ps_3.d3.bias === ps_new_2.bias === ps_3.d2.l1.bias + @test ps_3.d2.l2.weight == ps_new_ca_1.weight == ps_3.d1.weight + @test ps_3.d2.l2.bias == ps_new_ca_1.bias == ps_3.d1.bias + @test ps_3.d3.weight == ps_new_2.weight == ps_3.d2.l1.weight + @test ps_3.d3.bias == ps_new_2.bias == ps_3.d2.l1.bias # Input Checks non_disjoint_sharing = (("d2.l2", "d1"), ("d1", "d2.l1")) @@ -54,7 +54,7 @@ rng = get_stable_rng(12345) @test_throws ArgumentError Lux.share_parameters(ps, sharing, (ps_new_1, ps_new_2)) - ps_new_ca_1 = ComponentArray(ps_new_1) + ps_new_ca_1 = ComponentArray(ps_new_1 |> cpu) |> device @test_throws ArgumentError Lux.share_parameters(ps, sharing, (ps_new_ca_1, ps_new_2)) end diff --git a/test/layers/containers.jl b/test/layers/containers.jl index 11c9aa796a..ea7ab21659 100644 --- a/test/layers/containers.jl +++ b/test/layers/containers.jl @@ -311,8 +311,9 @@ end end @testset "complex alternatives" begin - layer = Maxout(WrappedFunction(x -> aType([0.5; 
0.1]) * x), - WrappedFunction(x -> aType([0.2; 0.7]) * x)) + A = aType([0.5 0.1]') + B = aType([0.2 0.7]') + layer = Maxout(WrappedFunction(x -> A * x), WrappedFunction(x -> B * x)) display(layer) ps, st = Lux.setup(rng, layer) .|> device x = [3.0 2.0] |> aType diff --git a/test/layers/conv.jl b/test/layers/conv.jl index 628d2b39ec..9ab5a5a399 100644 --- a/test/layers/conv.jl +++ b/test/layers/conv.jl @@ -1,10 +1,9 @@ -using Lux, NNlib, Random, Test, Zygote - -rng = Random.default_rng() -Random.seed!(rng, 0) +using Lux, NNlib, Test, Zygote include("../test_utils.jl") +rng = get_stable_rng(12345) + @testset "$mode: Pooling" for (mode, aType, device, ongpu) in MODES x = randn(rng, Float32, 10, 10, 3, 2) |> aType y = randn(rng, Float32, 20, 20, 3, 2) |> aType @@ -147,9 +146,9 @@ end @jet layer(x, ps, st) end - @testset "Variable BitWidth Parameters" begin - # https://github.com/FluxML/Flux.jl/issues/1421 - layer = Conv((5, 5), 10 => 20, identity; init_weight=Base.randn, + @testset "Variable BitWidth Parameters FluxML/Flux.jl#1421" begin + layer = Conv((5, 5), 10 => 20, identity; + init_weight=(rng, dims...) -> aType(randn(rng, Float64, dims...)), init_bias=(rng, dims...) -> aType(randn(rng, Float16, dims...))) display(layer) ps, st = Lux.setup(rng, layer) @@ -211,7 +210,7 @@ end __f = (x, ps) -> sum(first(layer(x, ps, st))) @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu - layer = Conv(k, 1 => 1; pad=Lux.SamePad(), dilation=k .÷ 2) + layer = Conv(k, 1 => 1; pad=Lux.SamePad(), dilation=max.(k .÷ 2, 1)) display(layer) ps, st = Lux.setup(rng, layer) .|> device @@ -233,7 +232,7 @@ end @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu end - @testset "Conv with non quadratic window #700" begin + @testset "Conv with non quadratic window FluxML/Flux.jl#700" begin x = zeros(Float32, 7, 7, 1, 1) x[4, 4, 1, 1] = 1 x = x |> aType @@ -416,9 +415,9 @@ end @jet layer(x, ps, st) end - @testset "Variable BitWidth Parameters" begin - # https://github.com/FluxML/Flux.jl/issues/1421 - layer = CrossCor((5, 5), 10 => 20, identity; init_weight=Base.randn, + @testset "Variable BitWidth Parameters FluxML/Flux.jl#1421" begin + layer = CrossCor((5, 5), 10 => 20, identity; + init_weight=(rng, dims...) -> aType(randn(rng, Float64, dims...)), init_bias=(rng, dims...) 
-> aType(randn(rng, Float16, dims...))) display(layer) ps, st = Lux.setup(rng, layer) @@ -439,7 +438,7 @@ end __f = (x, ps) -> sum(first(layer(x, ps, st))) @eval @test_gradients $__f $x $ps gpu_testing=$ongpu atol=1e-3 rtol=1e-3 - layer = CrossCor(k, 1 => 1; pad=Lux.SamePad(), dilation=k .÷ 2) + layer = CrossCor(k, 1 => 1; pad=Lux.SamePad(), dilation=max.(k .÷ 2, 1)) display(layer) ps, st = Lux.setup(rng, layer) .|> device diff --git a/test/layers/dropout.jl b/test/layers/dropout.jl index 9531715ae6..c84db98eed 100644 --- a/test/layers/dropout.jl +++ b/test/layers/dropout.jl @@ -36,7 +36,8 @@ end layer = AlphaDropout(p) display(layer) ps, st = Lux.setup(rng, layer) .|> device - x = randn(Float32, 5, 2) |> aType + # GPU compilation for mixed types fail atm + x = randn(typeof(p), 5, 2) |> aType x_, st_ = layer(x, ps, st) x__, st__ = layer(x, ps, st) diff --git a/test/layers/normalize.jl b/test/layers/normalize.jl index d73e2617b8..316981b2af 100644 --- a/test/layers/normalize.jl +++ b/test/layers/normalize.jl @@ -42,7 +42,7 @@ rng = get_stable_rng(12345) 0.1 .* var(Array(x); dims=2, corrected=false) .* (3 / 2) .+ 0.9 .* [1.0, 1.0]) - st_ = Lux.testmode(st_) + st_ = Lux.testmode(st_) |> device x_ = m(x, ps, st_)[1] |> cpu @test check_approx(x_[1], (1 .- 0.3) / sqrt(1.3), atol=1.0e-5) @@ -91,7 +91,7 @@ rng = get_stable_rng(12345) m = BatchNorm(32; affine) x = randn(Float32, 416, 416, 32, 1) |> aType display(m) - ps, st = Lux.setup(rng, m) + ps, st = Lux.setup(rng, m) .|> device st = Lux.testmode(st) m(x, ps, st) @test (@allocated m(x, ps, st)) < 100_000_000 diff --git a/test/layers/recurrent.jl b/test/layers/recurrent.jl index 3420d836ac..a0813712e8 100644 --- a/test/layers/recurrent.jl +++ b/test/layers/recurrent.jl @@ -61,187 +61,187 @@ rng = get_stable_rng(12345) # end # end -# @testset "$mode: LSTMCell" for (mode, aType, device, ongpu) in MODES -# for lstmcell in (LSTMCell(3 => 5), LSTMCell(3 => 5; use_bias=true), -# LSTMCell(3 => 5; use_bias=false)) -# display(lstmcell) -# ps, st = Lux.setup(rng, lstmcell) .|> device -# x = randn(rng, Float32, 3, 2) |> aType -# (y, carry), st_ = Lux.apply(lstmcell, x, ps, st) - -# @jet lstmcell(x, ps, st) -# @jet lstmcell((x, carry), ps, st) - -# function loss_loop_lstmcell(p) -# (y, carry), st_ = lstmcell(x, p, st) -# for i in 1:10 -# (y, carry), st_ = lstmcell((x, carry), p, st_) -# end -# return sum(abs2, y) -# end - -# @eval @test_gradients $loss_loop_lstmcell $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu - -# @test_throws ErrorException ps.train_state -# @test_throws ErrorException ps.train_memory -# end - -# @testset "Trainable hidden states" begin -# x = randn(rng, Float32, 3, 2) |> aType -# _lstm = LSTMCell(3 => 5; use_bias=false, train_state=false, train_memory=false) -# _ps, _st = Lux.setup(rng, _lstm) .|> device -# (_y, _carry), _ = Lux.apply(_lstm, x, _ps, _st) - -# lstm = LSTMCell(3 => 5; use_bias=false, train_state=false, train_memory=false) -# ps, st = Lux.setup(rng, lstm) .|> device -# ps = _ps -# (y, carry), _ = Lux.apply(lstm, x, ps, st) -# @test carry == _carry -# l, back = Zygote.pullback(p -> sum(abs2, 0 .- sum(lstm(x, p, st)[1][1])), ps) -# gs = back(one(l))[1] -# @test_throws ErrorException gs.bias -# @test_throws ErrorException gs.hidden_state -# @test_throws ErrorException gs.memory - -# lstm = LSTMCell(3 => 5; use_bias=false, train_state=true, train_memory=false) -# ps, st = Lux.setup(rng, lstm) .|> device -# ps = merge(_ps, (hidden_state=ps.hidden_state,)) -# (y, carry), _ = Lux.apply(lstm, x, ps, st) -# @test carry == _carry 
-# l, back = Zygote.pullback(p -> sum(abs2, 0 .- sum(lstm(x, p, st)[1][1])), ps) -# gs = back(one(l))[1] -# @test_throws ErrorException gs.bias -# @test !isnothing(gs.hidden_state) -# @test_throws ErrorException gs.memory - -# lstm = LSTMCell(3 => 5; use_bias=false, train_state=false, train_memory=true) -# ps, st = Lux.setup(rng, lstm) .|> device -# ps = merge(_ps, (memory=ps.memory,)) -# (y, carry), _ = Lux.apply(lstm, x, ps, st) -# @test carry == _carry -# l, back = Zygote.pullback(p -> sum(abs2, 0 .- sum(lstm(x, p, st)[1][1])), ps) -# gs = back(one(l))[1] -# @test_throws ErrorException gs.bias -# @test_throws ErrorException gs.hidden_state -# @test !isnothing(gs.memory) - -# lstm = LSTMCell(3 => 5; use_bias=false, train_state=true, train_memory=true) -# ps, st = Lux.setup(rng, lstm) .|> device -# ps = merge(_ps, (hidden_state=ps.hidden_state, memory=ps.memory)) -# (y, carry), _ = Lux.apply(lstm, x, ps, st) -# @test carry == _carry -# l, back = Zygote.pullback(p -> sum(abs2, 0 .- sum(lstm(x, p, st)[1][1])), ps) -# gs = back(one(l))[1] -# @test_throws ErrorException gs.bias -# @test !isnothing(gs.hidden_state) -# @test !isnothing(gs.memory) - -# lstm = LSTMCell(3 => 5; use_bias=true, train_state=true, train_memory=true) -# ps, st = Lux.setup(rng, lstm) .|> device -# ps = merge(_ps, (bias=ps.bias, hidden_state=ps.hidden_state, memory=ps.memory)) -# (y, carry), _ = Lux.apply(lstm, x, ps, st) -# l, back = Zygote.pullback(p -> sum(abs2, 0 .- sum(lstm(x, p, st)[1][1])), ps) -# gs = back(one(l))[1] -# @test !isnothing(gs.bias) -# @test !isnothing(gs.hidden_state) -# @test !isnothing(gs.memory) -# end -# end - -# @testset "$mode: GRUCell" for (mode, aType, device, ongpu) in MODES -# for grucell in (GRUCell(3 => 5), GRUCell(3 => 5; use_bias=true), -# GRUCell(3 => 5; use_bias=false)) -# display(grucell) -# ps, st = Lux.setup(rng, grucell) .|> device -# x = randn(rng, Float32, 3, 2) |> aType -# (y, carry), st_ = Lux.apply(grucell, x, ps, st) +@testset "$mode: LSTMCell" for (mode, aType, device, ongpu) in MODES + for lstmcell in (LSTMCell(3 => 5), LSTMCell(3 => 5; use_bias=true), + LSTMCell(3 => 5; use_bias=false)) + display(lstmcell) + ps, st = Lux.setup(rng, lstmcell) .|> device + x = randn(rng, Float32, 3, 2) |> aType + (y, carry), st_ = Lux.apply(lstmcell, x, ps, st) + + @jet lstmcell(x, ps, st) + @jet lstmcell((x, carry), ps, st) + + function loss_loop_lstmcell(p) + (y, carry), st_ = lstmcell(x, p, st) + for i in 1:10 + (y, carry), st_ = lstmcell((x, carry), p, st_) + end + return sum(abs2, y) + end -# @jet grucell(x, ps, st) -# @jet grucell((x, carry), ps, st) + @eval @test_gradients $loss_loop_lstmcell $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu -# function loss_loop_grucell(p) -# (y, carry), st_ = grucell(x, p, st) -# for i in 1:10 -# (y, carry), st_ = grucell((x, carry), p, st_) -# end -# return sum(abs2, y) -# end + @test_throws ErrorException ps.train_state + @test_throws ErrorException ps.train_memory + end -# @eval @test_gradients $loss_loop_grucell $ps atol=1e-2 rtol=1e-2 gpu_testing=$ongpu + @testset "Trainable hidden states" begin + x = randn(rng, Float32, 3, 2) |> aType + _lstm = LSTMCell(3 => 5; use_bias=false, train_state=false, train_memory=false) + _ps, _st = Lux.setup(rng, _lstm) .|> device + (_y, _carry), _ = Lux.apply(_lstm, x, _ps, _st) + + lstm = LSTMCell(3 => 5; use_bias=false, train_state=false, train_memory=false) + ps, st = Lux.setup(rng, lstm) .|> device + ps = _ps + (y, carry), _ = Lux.apply(lstm, x, ps, st) + @test carry == _carry + l, back = Zygote.pullback(p -> 
sum(abs2, 0 .- sum(lstm(x, p, st)[1][1])), ps) + gs = back(one(l))[1] + @test_throws ErrorException gs.bias + @test_throws ErrorException gs.hidden_state + @test_throws ErrorException gs.memory + + lstm = LSTMCell(3 => 5; use_bias=false, train_state=true, train_memory=false) + ps, st = Lux.setup(rng, lstm) .|> device + ps = merge(_ps, (hidden_state=ps.hidden_state,)) + (y, carry), _ = Lux.apply(lstm, x, ps, st) + @test carry == _carry + l, back = Zygote.pullback(p -> sum(abs2, 0 .- sum(lstm(x, p, st)[1][1])), ps) + gs = back(one(l))[1] + @test_throws ErrorException gs.bias + @test !isnothing(gs.hidden_state) + @test_throws ErrorException gs.memory + + lstm = LSTMCell(3 => 5; use_bias=false, train_state=false, train_memory=true) + ps, st = Lux.setup(rng, lstm) .|> device + ps = merge(_ps, (memory=ps.memory,)) + (y, carry), _ = Lux.apply(lstm, x, ps, st) + @test carry == _carry + l, back = Zygote.pullback(p -> sum(abs2, 0 .- sum(lstm(x, p, st)[1][1])), ps) + gs = back(one(l))[1] + @test_throws ErrorException gs.bias + @test_throws ErrorException gs.hidden_state + @test !isnothing(gs.memory) + + lstm = LSTMCell(3 => 5; use_bias=false, train_state=true, train_memory=true) + ps, st = Lux.setup(rng, lstm) .|> device + ps = merge(_ps, (hidden_state=ps.hidden_state, memory=ps.memory)) + (y, carry), _ = Lux.apply(lstm, x, ps, st) + @test carry == _carry + l, back = Zygote.pullback(p -> sum(abs2, 0 .- sum(lstm(x, p, st)[1][1])), ps) + gs = back(one(l))[1] + @test_throws ErrorException gs.bias + @test !isnothing(gs.hidden_state) + @test !isnothing(gs.memory) + + lstm = LSTMCell(3 => 5; use_bias=true, train_state=true, train_memory=true) + ps, st = Lux.setup(rng, lstm) .|> device + ps = merge(_ps, (bias=ps.bias, hidden_state=ps.hidden_state, memory=ps.memory)) + (y, carry), _ = Lux.apply(lstm, x, ps, st) + l, back = Zygote.pullback(p -> sum(abs2, 0 .- sum(lstm(x, p, st)[1][1])), ps) + gs = back(one(l))[1] + @test !isnothing(gs.bias) + @test !isnothing(gs.hidden_state) + @test !isnothing(gs.memory) + end +end -# @test_throws ErrorException ps.train_state -# end +@testset "$mode: GRUCell" for (mode, aType, device, ongpu) in MODES + for grucell in (GRUCell(3 => 5), GRUCell(3 => 5; use_bias=true), + GRUCell(3 => 5; use_bias=false)) + display(grucell) + ps, st = Lux.setup(rng, grucell) .|> device + x = randn(rng, Float32, 3, 2) |> aType + (y, carry), st_ = Lux.apply(grucell, x, ps, st) + + @jet grucell(x, ps, st) + @jet grucell((x, carry), ps, st) + + function loss_loop_grucell(p) + (y, carry), st_ = grucell(x, p, st) + for i in 1:10 + (y, carry), st_ = grucell((x, carry), p, st_) + end + return sum(abs2, y) + end -# @testset "Trainable hidden states" begin -# x = randn(rng, Float32, 3, 2) |> aType -# _gru = GRUCell(3 => 5; use_bias=false, train_state=false) -# _ps, _st = Lux.setup(rng, _gru) .|> device -# (_y, _carry), _ = Lux.apply(_gru, x, _ps, _st) - -# gru = GRUCell(3 => 5; use_bias=false, train_state=false) -# ps, st = Lux.setup(rng, gru) .|> device -# ps = _ps -# (y, carry), _ = Lux.apply(gru, x, ps, st) -# @test carry == _carry -# l, back = Zygote.pullback(p -> sum(abs2, 0 .- sum(gru(x, p, st)[1][1])), ps) -# gs = back(one(l))[1] -# @test_throws ErrorException gs.bias -# @test_throws ErrorException gs.hidden_state + @eval @test_gradients $loss_loop_grucell $ps atol=1e-2 rtol=1e-2 gpu_testing=$ongpu -# gru = GRUCell(3 => 5; use_bias=false, train_state=true) -# ps, st = Lux.setup(rng, gru) .|> device -# ps = merge(_ps, (hidden_state=ps.hidden_state,)) -# (y, carry), _ = Lux.apply(gru, x, ps, st) -# 
@test carry == _carry -# l, back = Zygote.pullback(p -> sum(abs2, 0 .- sum(gru(x, p, st)[1][1])), ps) -# gs = back(one(l))[1] -# @test !isnothing(gs.hidden_state) + @test_throws ErrorException ps.train_state + end -# gru = GRUCell(3 => 5; use_bias=true, train_state=true) -# ps, st = Lux.setup(rng, gru) .|> device -# ps = merge(_ps, (bias_h=ps.bias_h, bias_i=ps.bias_i, hidden_state=ps.hidden_state)) -# (y, carry), _ = Lux.apply(gru, x, ps, st) -# @test carry == _carry -# l, back = Zygote.pullback(p -> sum(abs2, 0 .- sum(gru(x, p, st)[1][1])), ps) -# gs = back(one(l))[1] -# @test !isnothing(gs.hidden_state) -# end -# end + @testset "Trainable hidden states" begin + x = randn(rng, Float32, 3, 2) |> aType + _gru = GRUCell(3 => 5; use_bias=false, train_state=false) + _ps, _st = Lux.setup(rng, _gru) .|> device + (_y, _carry), _ = Lux.apply(_gru, x, _ps, _st) + + gru = GRUCell(3 => 5; use_bias=false, train_state=false) + ps, st = Lux.setup(rng, gru) .|> device + ps = _ps + (y, carry), _ = Lux.apply(gru, x, ps, st) + @test carry == _carry + l, back = Zygote.pullback(p -> sum(abs2, 0 .- sum(gru(x, p, st)[1][1])), ps) + gs = back(one(l))[1] + @test_throws ErrorException gs.bias + @test_throws ErrorException gs.hidden_state + + gru = GRUCell(3 => 5; use_bias=false, train_state=true) + ps, st = Lux.setup(rng, gru) .|> device + ps = merge(_ps, (hidden_state=ps.hidden_state,)) + (y, carry), _ = Lux.apply(gru, x, ps, st) + @test carry == _carry + l, back = Zygote.pullback(p -> sum(abs2, 0 .- sum(gru(x, p, st)[1][1])), ps) + gs = back(one(l))[1] + @test !isnothing(gs.hidden_state) + + gru = GRUCell(3 => 5; use_bias=true, train_state=true) + ps, st = Lux.setup(rng, gru) .|> device + ps = merge(_ps, (bias_h=ps.bias_h, bias_i=ps.bias_i, hidden_state=ps.hidden_state)) + (y, carry), _ = Lux.apply(gru, x, ps, st) + @test carry == _carry + l, back = Zygote.pullback(p -> sum(abs2, 0 .- sum(gru(x, p, st)[1][1])), ps) + gs = back(one(l))[1] + @test !isnothing(gs.hidden_state) + end +end -# @testset "$mode: StatefulRecurrentCell" for (mode, aType, device, ongpu) in MODES -# for _cell in (RNNCell, LSTMCell, GRUCell), -# use_bias in (true, false), -# train_state in (true, false) +@testset "$mode: StatefulRecurrentCell" for (mode, aType, device, ongpu) in MODES + for _cell in (RNNCell, LSTMCell, GRUCell), + use_bias in (true, false), + train_state in (true, false) -# cell = _cell(3 => 5; use_bias, train_state) -# rnn = StatefulRecurrentCell(cell) -# display(rnn) -# x = randn(rng, Float32, 3, 2) |> aType -# ps, st = Lux.setup(rng, rnn) .|> device + cell = _cell(3 => 5; use_bias, train_state) + rnn = StatefulRecurrentCell(cell) + display(rnn) + x = randn(rng, Float32, 3, 2) |> aType + ps, st = Lux.setup(rng, rnn) .|> device -# y, st_ = rnn(x, ps, st) + y, st_ = rnn(x, ps, st) -# @jet rnn(x, ps, st) -# @jet rnn(x, ps, st_) + @jet rnn(x, ps, st) + @jet rnn(x, ps, st_) -# @test size(y) == (5, 2) -# @test st.carry === nothing -# @test st_.carry !== nothing + @test size(y) == (5, 2) + @test st.carry === nothing + @test st_.carry !== nothing -# st__ = Lux.update_state(st, :carry, nothing) -# @test st__.carry === nothing + st__ = Lux.update_state(st, :carry, nothing) + @test st__.carry === nothing -# function loss_loop_rnn(p) -# y, st_ = rnn(x, p, st) -# for i in 1:10 -# y, st_ = rnn(x, p, st_) -# end -# return sum(abs2, y) -# end + function loss_loop_rnn(p) + y, st_ = rnn(x, p, st) + for i in 1:10 + y, st_ = rnn(x, p, st_) + end + return sum(abs2, y) + end -# @eval @test_gradients $loss_loop_rnn $ps atol=1e-2 rtol=1e-2 
gpu_testing=$ongpu -# end -# end + @eval @test_gradients $loss_loop_rnn $ps atol=1e-2 rtol=1e-2 gpu_testing=$ongpu + end +end @testset "$mode: Recurrence" for (mode, aType, device, ongpu) in MODES for _cell in (RNNCell, LSTMCell, GRUCell), diff --git a/test/test_utils.jl b/test/test_utils.jl index 9c028f20d6..8f72f53ef1 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -2,6 +2,8 @@ using Lux, LuxCore, LuxLib, LuxTestUtils, Random, StableRNGs, Test, Zygote using LuxCUDA # CUDA Support using LuxTestUtils: @jet, @test_gradients, check_approx +CUDA.allowscalar(false) + const GROUP = get(ENV, "GROUP", "All") cpu_testing() = GROUP == "All" || GROUP == "CPU" diff --git a/test/utils.jl b/test/utils.jl index 64cbfe6b29..da2ae2bcc2 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -68,15 +68,18 @@ end @jet Lux.multigate(x, Val(2)) x = rand(6, 5) |> aType - res, (dx,) = Zygote.withgradient(x) do x + __f = x -> begin x1, _, x3 = Lux.multigate(x, Val(3)) return sum(x1) + sum(x3 .* 2) end + res, (dx,) = Zygote.withgradient(__f, x) @jet Lux.multigate(x, Val(3)) @test res == sum(x[1:2, :]) + 2sum(x[5:6, :]) - @test dx == [ones(2, 5); zeros(2, 5); fill(2, 2, 5)] + @test dx == aType([ones(2, 5); zeros(2, 5); fill(2, 2, 5)]) + + @eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu end @testset "$mode: ComponentArrays" for (mode, aType, device, ongpu) in MODES