Merge pull request #297 from LuxDL/ap/testing
Testing using LuxTestUtils.jl
avik-pal authored Apr 27, 2023
2 parents 3d48944 + 664f40a commit b3c4188
Showing 31 changed files with 1,305 additions and 1,376 deletions.
4 changes: 2 additions & 2 deletions Project.toml
@@ -1,7 +1,7 @@
name = "Lux"
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
version = "0.4.51"
version = "0.4.52"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -48,7 +48,7 @@ Flux = "0.13"
Functors = "0.2, 0.3, 0.4"
LuxCUDA = "0.1"
LuxCore = "0.1.3"
LuxLib = "0.1.7"
LuxLib = "0.2"
NNlib = "0.8"
Optimisers = "0.2"
Requires = "1"
8 changes: 7 additions & 1 deletion ext/LuxComponentArraysExt.jl
@@ -40,6 +40,12 @@ function Lux._merge(ca::ComponentArray, p::AbstractArray)
return ca
end

# Empty NamedTuple: Hack to avoid breaking precompilation
function ComponentArrays.ComponentArray(data::Vector{Any}, axes::Tuple{FlatAxis})
length(data) == 0 && return ComponentArray(Float32[], axes)
return ComponentArray{Any, 1, typeof(data), typeof(axes)}(data, axes)
end

# Parameter Sharing
Lux._parameter_structure(ps::ComponentArray) = Lux._parameter_structure(NamedTuple(ps))

@@ -54,7 +60,7 @@ function CRC.rrule(::Type{ComponentArray}, nt::NamedTuple)
"of shape $(size(res)) & type $(typeof(res))")
return nothing
end
CA_NT_pullback(Δ::ComponentArray) = (CRC.NoTangent(), NamedTuple(Δ))
CA_NT_pullback(Δ::ComponentArray) = CRC.NoTangent(), NamedTuple(Δ)
return res, CA_NT_pullback
end

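
A minimal sketch of what the empty-`NamedTuple` hack above guards against (not part of the diff; the `Float32`-backed result is an expectation based on the new method, not a verified behaviour):

```julia
using ComponentArrays

# The method above intercepts the `Vector{Any}` data path produced by an empty
# NamedTuple and returns a Float32-backed empty ComponentArray instead, which
# keeps Lux precompilation from breaking.
ps = ComponentArray(NamedTuple())

length(ps) == 0              # true: no parameters at all
ComponentArrays.getdata(ps)  # expected: 0-element Vector{Float32} with the patch
```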
17 changes: 16 additions & 1 deletion ext/LuxComponentArraysTrackerExt.jl
@@ -8,7 +8,11 @@ else
using ..Tracker
end

Tracker.param(ca::ComponentArray) = ComponentArray(Tracker.param(getdata(ca)), getaxes(ca))
function Tracker.param(ca::ComponentArray)
x = getdata(ca)
length(x) == 0 && return ComponentArray(Tracker.param(Float32[]), getaxes(ca))
return ComponentArray(Tracker.param(x), getaxes(ca))
end

Tracker.extract_grad!(ca::ComponentArray) = Tracker.extract_grad!(getdata(ca))

@@ -24,4 +28,15 @@ function Base.getindex(g::Tracker.Grads, x::ComponentArray)
return g[Tracker.tracker(getdata(x))]
end

# For TrackedArrays ignore Base.maybeview
## Tracker with views doesn't work quite well
@inline function Base.getproperty(x::ComponentVector{T, <:TrackedArray},
s::Symbol) where {T}
return getproperty(x, Val(s))
end

@inline function Base.getproperty(x::ComponentVector{T, <:TrackedArray}, v::Val) where {T}
return ComponentArrays._getindex(Base.getindex, x, v)
end

end
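
A rough illustration (not from the PR) of why the `getproperty` overloads above exist: `Base.maybeview` hands back `SubArray` views, which Tracker differentiates poorly, so property access on a tracked `ComponentVector` is rerouted to plain `getindex`.

```julia
using ComponentArrays, Tracker

ps = ComponentArray(weight=rand(Float32, 3, 2), bias=zeros(Float32, 3))
tps = Tracker.param(ps)   # still a ComponentArray, now backed by a TrackedVector

# With the overloads above this dispatches to `getindex` rather than
# `Base.maybeview`, so the slice stays a tracked array Tracker can handle.
w = tps.weight
```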
43 changes: 12 additions & 31 deletions ext/LuxFluxTransformExt.jl
@@ -95,13 +95,10 @@ m2(x, ps, st)
```
"""
function transform(l::T; preserve_ps_st::Bool=false, kwargs...) where {T}
@warn """Transformation for type $T not implemented. Using `FluxLayer` as
a fallback.""" maxlog=1
@warn "Transformation for type $T not implemented. Using `FluxLayer` as a fallback." maxlog=1

if !preserve_ps_st
@warn """`FluxLayer` uses the parameters and states of the `layer`. It is not
possible to NOT preserve the parameters and states. Ignoring this keyword
argument.""" maxlog=1
@warn "`FluxLayer` uses the parameters and states of the `layer`. It is not possible to NOT preserve the parameters and states. Ignoring this keyword argument." maxlog=1
end

return FluxLayer(l)
@@ -168,8 +165,7 @@ function transform(l::Flux.Parallel; kwargs...)
end

function transform(l::Flux.PairwiseFusion; kwargs...)
@warn """Flux.PairwiseFusion and Lux.PairwiseFusion are semantically different. Using
`FluxLayer` as a fallback.""" maxlog=1
@warn "Flux.PairwiseFusion and Lux.PairwiseFusion are semantically different. Using `FluxLayer` as a fallback." maxlog=1
return FluxLayer(l)
end

@@ -252,8 +248,7 @@ end
transform(l::Flux.Dropout; kwargs...) = Dropout(l.p; l.dims)

function transform(l::Flux.LayerNorm; kwargs...)
@warn """Flux.LayerNorm and Lux.LayerNorm are semantically different specifications.
Using `FluxLayer` as a fallback.""" maxlog=1
@warn "Flux.LayerNorm and Lux.LayerNorm are semantically different specifications. Using `FluxLayer` as a fallback." maxlog=1
return FluxLayer(l)
end

@@ -273,13 +268,9 @@ function transform(l::Flux.RNNCell; preserve_ps_st::Bool=false, force_preserve::
out_dims, in_dims = size(l.Wi)
if preserve_ps_st
if force_preserve
throw(FluxModelConversionError("Recurrent Cell: $(typeof(l)) for Flux uses a " *
"`reset!` mechanism which hasn't been " *
"extensively tested with `FluxLayer`. Rewrite " *
"the model manually to use `RNNCell`."))
throw(FluxModelConversionError("Recurrent Cell: $(typeof(l)) for Flux uses a `reset!` mechanism which hasn't been extensively tested with `FluxLayer`. Rewrite the model manually to use `RNNCell`."))
end
@warn """Preserving Parameters: `Wh` & `Wi` for `Flux.RNNCell` is ambiguous in Lux
and hence not supported. Ignoring these parameters.""" maxlog=1
@warn "Preserving Parameters: `Wh` & `Wi` for `Flux.RNNCell` is ambiguous in Lux and hence not supported. Ignoring these parameters." maxlog=1
return RNNCell(in_dims => out_dims, l.σ; init_bias=(args...) -> copy(l.b),
init_state=(args...) -> copy(l.state0))
else
@@ -292,13 +283,9 @@ function transform(l::Flux.LSTMCell; preserve_ps_st::Bool=false, force_preserve:
out_dims = _out_dims ÷ 4
if preserve_ps_st
if force_preserve
throw(FluxModelConversionError("Recurrent Cell: $(typeof(l)) for Flux uses a " *
"`reset!` mechanism which hasn't been " *
"extensively tested with `FluxLayer`. Rewrite " *
"the model manually to use `LSTMCell`."))
throw(FluxModelConversionError("Recurrent Cell: $(typeof(l)) for Flux uses a `reset!` mechanism which hasn't been extensively tested with `FluxLayer`. Rewrite the model manually to use `LSTMCell`."))
end
@warn """Preserving Parameters: `Wh` & `Wi` for `Flux.LSTMCell` is ambiguous in Lux
and hence not supported. Ignoring these parameters.""" maxlog=1
@warn "Preserving Parameters: `Wh` & `Wi` for `Flux.LSTMCell` is ambiguous in Lux and hence not supported. Ignoring these parameters." maxlog=1
bs = Lux.multigate(l.b, Val(4))
_s, _m = copy.(l.state0)
return LSTMCell(in_dims => out_dims; init_bias=_const_return_anon_function.(bs),
@@ -313,13 +300,9 @@ function transform(l::Flux.GRUCell; preserve_ps_st::Bool=false, force_preserve::
out_dims = _out_dims ÷ 3
if preserve_ps_st
if force_preserve
throw(FluxModelConversionError("Recurrent Cell: $(typeof(l)) for Flux uses a " *
"`reset!` mechanism which hasn't been " *
"extensively tested with `FluxLayer`. Rewrite " *
"the model manually to use `GRUCell`."))
throw(FluxModelConversionError("Recurrent Cell: $(typeof(l)) for Flux uses a `reset!` mechanism which hasn't been extensively tested with `FluxLayer`. Rewrite the model manually to use `GRUCell`."))
end
@warn """Preserving Parameters: `Wh` & `Wi` for `Flux.GRUCell` is ambiguous in Lux
and hence not supported. Ignoring these parameters.""" maxlog=1
@warn "Preserving Parameters: `Wh` & `Wi` for `Flux.GRUCell` is ambiguous in Lux and hence not supported. Ignoring these parameters." maxlog=1
bs = Lux.multigate(l.b, Val(3))
return GRUCell(in_dims => out_dims; init_bias=_const_return_anon_function.(bs),
init_state=(args...) -> copy(l.state0))
@@ -333,8 +316,7 @@ function transform(l::Flux.BatchNorm; preserve_ps_st::Bool=false,
if preserve_ps_st
if l.track_stats
force_preserve && return FluxLayer(l)
@warn """Preserving the state of `Flux.BatchNorm` is currently not supported.
Ignoring the state.""" maxlog=1
@warn "Preserving the state of `Flux.BatchNorm` is currently not supported. Ignoring the state." maxlog=1
end
if l.affine
return BatchNorm(l.chs, l.λ; l.affine, l.track_stats, epsilon=l.ϵ, l.momentum,
@@ -352,8 +334,7 @@ function transform(l::Flux.GroupNorm; preserve_ps_st::Bool=false,
if preserve_ps_st
if l.track_stats
force_preserve && return FluxLayer(l)
@warn """Preserving the state of `Flux.GroupNorm` is currently not supported.
Ignoring the state.""" maxlog=1
@warn "Preserving the state of `Flux.GroupNorm` is currently not supported. Ignoring the state." maxlog=1
end
if l.affine
return GroupNorm(l.chs, l.G, l.λ; l.affine, l.track_stats, epsilon=l.ϵ,
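
For context, a hedged usage sketch of the `transform` API these methods extend; the exported name `Lux.transform`, the Flux `Pair` constructor syntax, and the data shapes are assumptions, not part of this diff.

```julia
using Flux, Lux, Random

fmodel = Flux.Chain(Flux.Dense(2 => 3, tanh), Flux.Dense(3 => 1))
lmodel = Lux.transform(fmodel)                # convert Flux layers to Lux layers

ps, st = Lux.setup(Random.default_rng(), lmodel)
y, _ = lmodel(randn(Float32, 2, 8), ps, st)   # Lux calling convention: (x, ps, st)
```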
6 changes: 5 additions & 1 deletion ext/LuxTrackerExt.jl
@@ -1,7 +1,7 @@
module LuxTrackerExt

isdefined(Base, :get_extension) ? (using Tracker) : (using ..Tracker)
using Functors, Lux, Setfield
using ChainRulesCore, Functors, Lux, Setfield

# Type Piracy: Need to upstream
Tracker.param(nt::NamedTuple) = fmap(Tracker.param, nt)
Expand All @@ -18,6 +18,10 @@ Tracker.data(t::Tuple) = map(Tracker.data, t)
# Weight Norm Patch
@inline Lux._norm(x::TrackedArray; dims=Colon()) = sqrt.(sum(abs2.(x); dims))

# multigate chain rules
@inline Lux._gate(x::TrackedVector, h::Int, n::Int) = x[Lux._gate(h, n)]
@inline Lux._gate(x::TrackedMatrix, h::Int, n::Int) = x[Lux._gate(h, n), :]

# Lux.Training
function Lux.Training.compute_gradients(::Lux.Training.TrackerVJP,
objective_function::Function, data,
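
The two `Lux._gate` overloads added above avoid view-based indexing on `TrackedArray`s; a small sketch of what they do, assuming `Lux._gate(h, n)` expands to the index range `(1:h) .+ h * (n - 1)` as defined in `src/utils.jl`:

```julia
using Lux, Tracker

x = Tracker.param(rand(Float32, 6))

# Equivalent to x[1:2] and x[3:4]; this is how `multigate` splits stacked
# gate parameters of recurrent cells.
h1 = Lux._gate(x, 2, 1)
h2 = Lux._gate(x, 2, 2)
```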
11 changes: 10 additions & 1 deletion ext/LuxZygoteExt.jl
@@ -1,8 +1,15 @@
module LuxZygoteExt

isdefined(Base, :get_extension) ? (using Zygote) : (using ..Zygote)
if isdefined(Base, :get_extension)
using Zygote
using Zygote: Pullback
else
using ..Zygote
using ..Zygote: Pullback
end

using Adapt, LuxCUDA, Lux, Setfield
using TruncatedStacktraces: @truncate_stacktrace

Adapt.adapt_storage(::Lux.LuxCUDAAdaptor, x::Zygote.OneElement) = CUDA.cu(collect(x))

@@ -19,4 +26,6 @@ function Lux.Training.compute_gradients(::Lux.Training.ZygoteVJP,
return grads, loss, stats, ts
end

@truncate_stacktrace Pullback 1

end
2 changes: 1 addition & 1 deletion src/chainrules.jl
Expand Up @@ -75,7 +75,7 @@ function CRC.rrule(::typeof(multigate), x::AbstractArray, c::Val{N}) where {N}
dyᵢ isa AbstractZero && return
@. dxᵢ += dyᵢ
end
return (NoTangent(), dx, NoTangent(), NoTangent())
return (NoTangent(), dx, NoTangent())
end
return multigate(x, c), multigate_pullback
end
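
With the signature in the hunk header, `rrule(::typeof(multigate), x, c)` has three primal arguments, so the pullback returns exactly three tangents. A quick sanity sketch (the Zygote usage is illustrative, not part of this diff):

```julia
using Lux, Zygote

x = randn(Float32, 6, 4)
y1, y2, y3 = Lux.multigate(x, Val(3))   # three 2×4 chunks along dimension 1

# Gradients flow through the custom rrule; only the first chunk contributes here.
gs, = Zygote.gradient(x -> sum(abs2, first(Lux.multigate(x, Val(3)))), x)
size(gs) == size(x)                     # true
```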
10 changes: 0 additions & 10 deletions src/contrib/freeze.jl
@@ -86,16 +86,6 @@ function initialstates(rng::AbstractRNG, l::FrozenLayer{which_params}) where {wh
return (frozen_params=(; ps_frozen...), states=st)
end

_merge(nt1::NamedTuple, nt2::NamedTuple) = merge(nt1, nt2)
function _merge(p::AbstractArray, nt::NamedTuple)
@assert length(p) == 0
return nt
end
function _merge(nt::NamedTuple, p::AbstractArray)
@assert length(p) == 0
return nt
end

function (f::FrozenLayer)(x, ps, st::NamedTuple)
y, st_ = f.layer(x, _merge(ps, st.frozen_params), st.states)
st = merge(st, (; states=st_))
7 changes: 2 additions & 5 deletions src/layers/normalize.jl
@@ -487,10 +487,7 @@ function initialparameters(rng::AbstractRNG,
v = ps_layer[k]
if k in which_params
if all(iszero, v)
msg = ("Parameter $(k) is completely zero. This will result in NaN " *
"gradients. Either remove this parameter from `which_params` or " *
"modify the initialization in the actual layer. Typically this is " *
"controlled using the `init_$(k)` keyword argument.")
msg = ("Parameter $(k) is completely zero. This will result in NaN gradients. Either remove this parameter from `which_params` or modify the initialization in the actual layer. Typically this is controlled using the `init_$(k)` keyword argument.")
# FIXME(@avik-pal): This is not really an ArgumentError
throw(ArgumentError(msg))
end
@@ -510,7 +507,7 @@ initialstates(rng::AbstractRNG, wn::WeightNorm) = initialstates(rng, wn.layer)

function (wn::WeightNorm)(x, ps, st::NamedTuple)
_ps = _get_normalized_parameters(wn, wn.dims, ps.normalized)
return Lux.apply(wn.layer, x, merge(_ps, ps.unnormalized), st)
return Lux.apply(wn.layer, x, _merge(_ps, ps.unnormalized), st)
end

@inbounds @generated function _get_normalized_parameters(::WeightNorm{which_params},
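
A hedged sketch of the failure mode the sharpened error message above describes; the `WeightNorm` constructor arguments and the zero-initialized bias of `Dense` follow standard Lux conventions and are assumptions here, not taken from this diff.

```julia
using Lux, Random

# `Dense` initializes its bias to zeros by default, so normalizing over :bias
# would give a zero norm and NaN gradients — setup is expected to throw the
# ArgumentError constructed above.
wn_bad = WeightNorm(Dense(3 => 4), (:weight, :bias))
# ps, st = Lux.setup(Random.default_rng(), wn_bad)   # expected ArgumentError

# Normalizing only the weight is fine:
wn = WeightNorm(Dense(3 => 4), (:weight,))
ps, st = Lux.setup(Random.default_rng(), wn)
```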
16 changes: 15 additions & 1 deletion src/utils.jl
Expand Up @@ -198,7 +198,10 @@ end
Split up `x` into `N` equally sized chunks (along dimension `1`).
"""
@inline multigate(x::AbstractArray, ::Val{N}) where {N} = _gate.((x,), size(x, 1) ÷ N, 1:N)
@inline function multigate(x::AbstractArray, ::Val{N}) where {N}
# return map(i -> _gate(x, size(x, 1) ÷ N, i), 1:N)
return ntuple(i -> _gate(x, size(x, 1) ÷ N, i), N)
end

# Val utilities
get_known(::Val{T}) where {T} = T
@@ -288,3 +291,14 @@ in the backward pass.
"""
@inline foldl_init(op, x) = foldl_init(op, x, nothing)
@inline foldl_init(op, x, init) = foldl(op, x; init)

# Merging Exotic Types
_merge(nt1::NamedTuple, nt2::NamedTuple) = merge(nt1, nt2)
function _merge(p::AbstractArray, nt::NamedTuple)
@assert length(p) == 0
return nt
end
function _merge(nt::NamedTuple, p::AbstractArray)
@assert length(p) == 0
return nt
end
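
The `_merge` helper moves here from `src/contrib/freeze.jl` so that `WeightNorm` (see `src/layers/normalize.jl` above) can reuse it. Roughly, it treats an empty parameter array as an empty `NamedTuple`:

```julia
using Lux

Lux._merge((a = 1,), (b = 2,))    # (a = 1, b = 2) — plain NamedTuple merge
Lux._merge(Float32[], (b = 2,))   # empty array on one side -> (b = 2,)
Lux._merge((a = 1,), Float32[])   # -> (a = 1,)
# Lux._merge(ones(3), (b = 2,))   # non-empty array trips the @assert
```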
2 changes: 2 additions & 0 deletions test/LocalPreferences.toml
@@ -0,0 +1,2 @@
[LuxTestUtils]
target_modules = ["Lux", "LuxCore", "LuxLib"]
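
These preferences configure LuxTestUtils.jl, the new test-only dependency named in the commit message, so that its JET-based checks only report problems originating from Lux, LuxCore, or LuxLib rather than from upstream packages. A hedged sketch of how a test might use it (the `@jet` macro is assumed to be LuxTestUtils' API; it is not shown in this diff):

```julia
using Lux, LuxTestUtils, Random

model = Dense(2 => 3)
ps, st = Lux.setup(Random.default_rng(), model)
x = randn(Float32, 2, 4)

# JET analysis of the forward pass; reports outside `target_modules` are ignored.
@jet model(x, ps, st)
```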
12 changes: 5 additions & 7 deletions test/Project.toml
@@ -1,23 +1,21 @@
[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11"
LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]

2 comments on commit b3c4188

@avik-pal
Member Author


@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/82434

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.4.52 -m "<description of version>" b3c418811b0a288567cc5c0e1417068e1992fc85
git push origin v0.4.52
