From 30187071a0491844161ebcf26fc184e7bad2e38f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 31 Aug 2022 18:12:49 -0700 Subject: [PATCH] Named Layers inside Generic Containers (#143) --- CHANGELOG.md | 7 +- Project.toml | 2 +- src/Lux.jl | 1 + src/layers/basic.jl | 400 ---------------------------------- src/layers/containers.jl | 439 ++++++++++++++++++++++++++++++++++++++ test/layers/basic.jl | 125 ----------- test/layers/containers.jl | 238 +++++++++++++++++++++ test/runtests.jl | 1 + test/test_utils.jl | 4 + 9 files changed, 690 insertions(+), 527 deletions(-) create mode 100644 src/layers/containers.jl create mode 100644 test/layers/containers.jl diff --git a/CHANGELOG.md b/CHANGELOG.md index aa6706d8dd..179098d9d5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,8 +1,13 @@ # v0.4 +## v0.4.19 + + - Generic Container layers (like `Chain`, `Parallel`, etc.) can now used custom naming for + their internal layers. + ## v0.4.17 - - Major breakcing change in experimental Recurrent Cell Implementations. + - Major breaking change in experimental Recurrent Cell Implementations. ## v0.4.14 - Deprecate `bias` in favor of `use_bias` for `RNNCell`. diff --git a/Project.toml b/Project.toml index e3ba267201..360f20b1dd 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Lux" uuid = "b2108857-7c20-44ae-9111-449ecde12c47" authors = ["Avik Pal and contributors"] -version = "0.4.18" +version = "0.4.19" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/Lux.jl b/src/Lux.jl index 6f261eb9dc..cd6b84e527 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -32,6 +32,7 @@ include("core.jl") include("adapt.jl") # Layer Implementations include("layers/basic.jl") +include("layers/containers.jl") include("layers/normalize.jl") include("layers/conv.jl") include("layers/dropout.jl") diff --git a/src/layers/basic.jl b/src/layers/basic.jl index bc79f6dcd6..8e8573112e 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -122,406 +122,6 @@ function Base.show(io::IO, w::WrappedFunction) return print(io, "WrappedFunction(", w.func, ")") end -""" - SkipConnection(layer, connection) - -Create a skip connection which consists of a layer or [`Chain`](@ref) of consecutive layers -and a shortcut connection linking the block's input to the output through a user-supplied -2-argument callable. The first argument to the callable will be propagated through the given -`layer` while the second is the unchanged, "skipped" input. - -The simplest "ResNet"-type connection is just `SkipConnection(layer, +)`. - -## Arguments - - - `layer`: Layer or `Chain` of layers to be applied to the input - - `connection`: A 2-argument function that takes `layer(input)` and the input - -## Inputs - - - `x`: Will be passed directly to `layer` - -## Returns - - - Output of `connection(layer(input), input)` - - Updated state of `layer` - -## Parameters - - - Parameters of `layer` - -## States - - - States of `layer` - -See [`Parallel`](@ref) for a more general implementation. -""" -struct SkipConnection{T <: AbstractExplicitLayer, F} <: - AbstractExplicitContainerLayer{(:layers,)} - layers::T - connection::F -end - -@inline function (skip::SkipConnection)(x, ps, st::NamedTuple) - mx, st = skip.layers(x, ps, st) - return skip.connection(mx, x), st -end - -""" - Parallel(connection, layers...) - -Create a layer which passes an input to each path in `layers`, before reducing the output -with `connection`. 
- -## Arguments - - - `layers`: A list of `N` Lux layers - - `connection`: An `N`-argument function that is called after passing the input through - each layer. If `connection = nothing`, we return a tuple - `Parallel(nothing, f, g)(x, y) = (f(x), g(y))` - -## Inputs - - - `x`: If `x` is not a tuple, then return is computed as - `connection([l(x) for l in layers]...)`. Else one is passed to each layer, thus - `Parallel(+, f, g)(x, y) = f(x) + g(y)`. - -## Returns - - - See the Inputs section for how the output is computed - - Updated state of the `layers` - -## Parameters - - - Parameters of each `layer` wrapped in a NamedTuple with - `fields = layer_1, layer_2, ..., layer_N` - -## States - - - States of each `layer` wrapped in a NamedTuple with - `fields = layer_1, layer_2, ..., layer_N` - -See also [`SkipConnection`](@ref) which is `Parallel` with one identity. -""" -struct Parallel{F, T <: NamedTuple} <: AbstractExplicitContainerLayer{(:layers,)} - connection::F - layers::T -end - -function Parallel(connection, layers...) - names = ntuple(i -> Symbol("layer_$i"), length(layers)) - return Parallel(connection, NamedTuple{names}(layers)) -end - -function (m::Parallel)(x, ps, st::NamedTuple) - return applyparallel(m.layers, m.connection, x, ps, st) -end - -@generated function applyparallel(layers::NamedTuple{names}, connection::C, x::T, ps, - st::NamedTuple) where {names, C, T} - N = length(names) - y_symbols = [gensym() for _ in 1:(N + 1)] - st_symbols = [gensym() for _ in 1:N] - getinput(i) = T <: Tuple ? :(x[$i]) : :x - calls = [] - append!(calls, - [:(($(y_symbols[i]), $(st_symbols[i])) = layers[$i]($(getinput(i)), - ps.$(names[i]), - st.$(names[i]))) - for i in 1:N]) - push!(calls, :(st = NamedTuple{$names}((($(Tuple(st_symbols)...),))))) - if C == Nothing - push!(calls, :($(y_symbols[N + 1]) = tuple($(Tuple(y_symbols[1:N])...)))) - else - push!(calls, :($(y_symbols[N + 1]) = connection($(Tuple(y_symbols[1:N])...)))) - end - push!(calls, :(return $(y_symbols[N + 1]), st)) - return Expr(:block, calls...) -end - -Base.keys(m::Parallel) = Base.keys(getfield(m, :layers)) - -""" - BranchLayer(layers...) - -Takes an input `x` and passes it through all the `layers` and returns a tuple of the -outputs. - -## Arguments - - - `layers`: A list of `N` Lux layers - -## Inputs - - - `x`: Will be directly passed to each of the `layers` - -## Returns - - - Tuple: `(layer_1(x), layer_2(x), ..., layer_N(x))` - - Updated state of the `layers` - -## Parameters - - - Parameters of each `layer` wrapped in a NamedTuple with - `fields = layer_1, layer_2, ..., layer_N` - -## States - - - States of each `layer` wrapped in a NamedTuple with - `fields = layer_1, layer_2, ..., layer_N` - -## Comparison with [`Parallel`](@ref) - -This is slightly different from `Parallel(nothing, layers...)` - - - If the input is a tuple, `Parallel` will pass each element individually to each layer - - - `BranchLayer` essentially assumes 1 input comes in and is branched out into `N` outputs - -## Example - -An easy way to replicate an input to an NTuple is to do - -```julia -l = BranchLayer(NoOpLayer(), NoOpLayer(), NoOpLayer()) -``` -""" -struct BranchLayer{T <: NamedTuple} <: AbstractExplicitContainerLayer{(:layers,)} - layers::T -end - -function BranchLayer(layers...) 
- names = ntuple(i -> Symbol("layer_$i"), length(layers)) - return BranchLayer(NamedTuple{names}(layers)) -end - -function (m::BranchLayer)(x, ps, st::NamedTuple) - return applybranching(m.layers, x, ps, st) -end - -@generated function applybranching(layers::NamedTuple{names}, x, ps, - st::NamedTuple) where {names} - N = length(names) - y_symbols = [gensym() for _ in 1:N] - st_symbols = [gensym() for _ in 1:N] - calls = [] - append!(calls, - [:(($(y_symbols[i]), $(st_symbols[i])) = layers[$i](x, ps.$(names[i]), - st.$(names[i]))) - for i in 1:N]) - push!(calls, :(st = NamedTuple{$names}((($(Tuple(st_symbols)...),))))) - push!(calls, :(return tuple($(Tuple(y_symbols)...)), st)) - return Expr(:block, calls...) -end - -Base.keys(m::BranchLayer) = Base.keys(getfield(m, :layers)) - -""" - PairwiseFusion(connection, layers...) - -``` -x1 → layer1 → y1 ↘ - connection → layer2 → y2 ↘ - x2 ↗ connection → y3 - x3 ↗ -``` - -## Arguments - - - `connection`: Takes 2 inputs and combines them - - `layers`: [`AbstractExplicitLayer`](@ref)s - -## Inputs - -Layer behaves differently based on input type: - - 1. If the input `x` is a tuple of length `N + 1`, then the `layers` must be a tuple of - length `N`. The computation is as follows - -```julia -y = x[1] -for i in 1:N - y = connection(x[i + 1], layers[i](y)) -end -``` - - 2. Any other kind of input - -```julia -y = x -for i in 1:N - y = connection(x, layers[i](y)) -end -``` - -## Returns - - - See Inputs section for how the return value is computed - - Updated model state for all the contained layers - -## Parameters - - - Parameters of each `layer` wrapped in a NamedTuple with - `fields = layer_1, layer_2, ..., layer_N` - -## States - - - States of each `layer` wrapped in a NamedTuple with - `fields = layer_1, layer_2, ..., layer_N` -""" -struct PairwiseFusion{F, T <: NamedTuple} <: AbstractExplicitContainerLayer{(:layers,)} - connection::F - layers::T -end - -function PairwiseFusion(connection, layers...) - names = ntuple(i -> Symbol("layer_$i"), length(layers)) - return PairwiseFusion(connection, NamedTuple{names}(layers)) -end - -function (m::PairwiseFusion)(x, ps, st::NamedTuple) - return applypairwisefusion(m.layers, m.connection, x, ps, st) -end - -@generated function applypairwisefusion(layers::NamedTuple{names}, connection::C, x::T, ps, - st::NamedTuple) where {names, C, T} - N = length(names) - y_symbols = [gensym() for _ in 1:(N + 1)] - st_symbols = [gensym() for _ in 1:N] - getinput(i) = T <: Tuple ? :(x[$i]) : :x - calls = [:($(y_symbols[N + 1]) = $(getinput(1)))] - append!(calls, - [:(($(y_symbols[i]), $(st_symbols[i])) = layers[$i]($(y_symbols[N + 1]), - ps.$(names[i]), - st.$(names[i])); - $(y_symbols[N + 1]) = connection($(y_symbols[i]), $(getinput(i + 1)))) - for i in 1:N]) - push!(calls, :(st = NamedTuple{$names}((($(Tuple(st_symbols)...),))))) - push!(calls, :(return $(y_symbols[N + 1]), st)) - return Expr(:block, calls...) -end - -Base.keys(m::PairwiseFusion) = Base.keys(getfield(m, :layers)) - -""" - Chain(layers...; disable_optimizations::Bool = false) - -Collects multiple layers / functions to be called in sequence on a given input. - -## Arguments - - - `layers`: A list of `N` Lux layers - -## Keyword Arguments - - - `disable_optimizations`: Prevents any structural optimization - -## Inputs - -Input `x` is passed sequentially to each layer, and must conform to the input requirements -of the internal layers. 
- -## Returns - - - Output after sequentially applying all the layers to `x` - - Updated model states - -## Parameters - - - Parameters of each `layer` wrapped in a NamedTuple with - `fields = layer_1, layer_2, ..., layer_N` - -## States - - - States of each `layer` wrapped in a NamedTuple with - `fields = layer_1, layer_2, ..., layer_N` - -## Optimizations - -Performs a few optimizations to generate reasonable architectures. Can be disabled using -keyword argument `disable_optimizations`. - - - All sublayers are recursively optimized. - - If a function `f` is passed as a layer and it doesn't take 3 inputs, it is converted to - a [`WrappedFunction`](@ref)(`f`) which takes only one input. - - If the layer is a Chain, it is flattened. - - [`NoOpLayer`](@ref)s are removed. - - If there is only 1 layer (left after optimizations), then it is returned without the - `Chain` wrapper. - - If there are no layers (left after optimizations), a [`NoOpLayer`](@ref) is returned. - -## Example - -```julia -c = Chain(Dense(2, 3, relu), BatchNorm(3), Dense(3, 2)) -``` -""" -struct Chain{T} <: AbstractExplicitContainerLayer{(:layers,)} - layers::T - function Chain(xs...; disable_optimizations::Bool=false) - xs = disable_optimizations ? xs : flatten_model(xs) - length(xs) == 0 && return NoOpLayer() - length(xs) == 1 && return first(xs) - names = ntuple(i -> Symbol("layer_$i"), length(xs)) - layers = NamedTuple{names}(xs) - return new{typeof(layers)}(layers) - end - function Chain(xs::AbstractVector; disable_optimizations::Bool=false) - return Chain(xs...; disable_optimizations) - end -end - -function flatten_model(layers::Union{AbstractVector, Tuple}) - new_layers = [] - for l in layers - f = flatten_model(l) - if f isa Tuple || f isa AbstractVector - append!(new_layers, f) - elseif f isa Function - if !hasmethod(f, (Any, Union{ComponentArray, NamedTuple}, NamedTuple)) - if f === identity - continue - else - push!(new_layers, WrappedFunction(f)) - end - else - push!(new_layers, f) - end - elseif f isa Chain - append!(new_layers, f.layers) - elseif f isa NoOpLayer - continue - else - push!(new_layers, f) - end - end - return layers isa AbstractVector ? new_layers : Tuple(new_layers) -end - -flatten_model(x) = x - -function (c::Chain)(x, ps, st::NamedTuple) - return applychain(c.layers, x, ps, st) -end - -@generated function applychain(layers::NamedTuple{fields}, x, ps, - st::NamedTuple{fields}) where {fields} - N = length(fields) - x_symbols = [gensym() for _ in 1:N] - st_symbols = [gensym() for _ in 1:N] - calls = [:(($(x_symbols[1]), $(st_symbols[1])) = layers[1](x, ps.layer_1, st.layer_1))] - append!(calls, - [:(($(x_symbols[i]), $(st_symbols[i])) = layers[$i]($(x_symbols[i - 1]), - ps.$(fields[i]), - st.$(fields[i]))) - for i in 2:N]) - push!(calls, :(st = NamedTuple{$fields}((($(Tuple(st_symbols)...),))))) - push!(calls, :(return $(x_symbols[N]), st)) - return Expr(:block, calls...) -end - -Base.keys(m::Chain) = Base.keys(getfield(m, :layers)) - """ Dense(in_dims => out_dims, activation=identity; init_weight=glorot_uniform, init_bias=zeros32, bias::Bool=true) diff --git a/src/layers/containers.jl b/src/layers/containers.jl new file mode 100644 index 0000000000..d60e216510 --- /dev/null +++ b/src/layers/containers.jl @@ -0,0 +1,439 @@ + +""" + SkipConnection(layer, connection) + +Create a skip connection which consists of a layer or [`Chain`](@ref) of consecutive layers +and a shortcut connection linking the block's input to the output through a user-supplied +2-argument callable. 
The first argument to the callable will be propagated through the given +`layer` while the second is the unchanged, "skipped" input. + +The simplest "ResNet"-type connection is just `SkipConnection(layer, +)`. + +## Arguments + + - `layer`: Layer or `Chain` of layers to be applied to the input + - `connection`: A 2-argument function that takes `layer(input)` and the input + +## Inputs + + - `x`: Will be passed directly to `layer` + +## Returns + + - Output of `connection(layer(input), input)` + - Updated state of `layer` + +## Parameters + + - Parameters of `layer` + +## States + + - States of `layer` + +See [`Parallel`](@ref) for a more general implementation. +""" +struct SkipConnection{T <: AbstractExplicitLayer, F} <: + AbstractExplicitContainerLayer{(:layers,)} + layers::T + connection::F +end + +@inline function (skip::SkipConnection)(x, ps, st::NamedTuple) + mx, st = skip.layers(x, ps, st) + return skip.connection(mx, x), st +end + +""" + Parallel(connection, layers...) + Parallel(connection; layers...) + +Create a layer which passes an input to each path in `layers`, before reducing the output +with `connection`. + +## Arguments + + - `connection`: An `N`-argument function that is called after passing the input through + each layer. If `connection = nothing`, we return a tuple + `Parallel(nothing, f, g)(x, y) = (f(x), g(y))` + + - Layers can be specified in two formats: + + + A list of `N` Lux layers + + Specified as `N` keyword arguments. + +## Inputs + + - `x`: If `x` is not a tuple, then return is computed as + `connection([l(x) for l in layers]...)`. Else one is passed to each layer, thus + `Parallel(+, f, g)(x, y) = f(x) + g(y)`. + +## Returns + + - See the Inputs section for how the output is computed + - Updated state of the `layers` + +## Parameters + + - Parameters of each `layer` wrapped in a NamedTuple with + `fields = layer_1, layer_2, ..., layer_N` (naming changes if using the kwargs API) + +## States + + - States of each `layer` wrapped in a NamedTuple with + `fields = layer_1, layer_2, ..., layer_N` (naming changes if using the kwargs API) + +See also [`SkipConnection`](@ref) which is `Parallel` with one identity. +""" +struct Parallel{F, T <: NamedTuple} <: AbstractExplicitContainerLayer{(:layers,)} + connection::F + layers::T +end + +function Parallel(connection, layers...) + names = ntuple(i -> Symbol("layer_$i"), length(layers)) + return Parallel(connection, NamedTuple{names}(layers)) +end + +Parallel(connection; kwargs...) = Parallel(connection, (; kwargs...)) + +function (m::Parallel)(x, ps, st::NamedTuple) + return applyparallel(m.layers, m.connection, x, ps, st) +end + +@generated function applyparallel(layers::NamedTuple{names}, connection::C, x::T, ps, + st::NamedTuple) where {names, C, T} + N = length(names) + y_symbols = [gensym() for _ in 1:(N + 1)] + st_symbols = [gensym() for _ in 1:N] + getinput(i) = T <: Tuple ? :(x[$i]) : :x + calls = [] + append!(calls, + [:(($(y_symbols[i]), $(st_symbols[i])) = layers.$(names[i])($(getinput(i)), + ps.$(names[i]), + st.$(names[i]))) + for i in 1:N]) + push!(calls, :(st = NamedTuple{$names}((($(Tuple(st_symbols)...),))))) + if C == Nothing + push!(calls, :($(y_symbols[N + 1]) = tuple($(Tuple(y_symbols[1:N])...)))) + else + push!(calls, :($(y_symbols[N + 1]) = connection($(Tuple(y_symbols[1:N])...)))) + end + push!(calls, :(return $(y_symbols[N + 1]), st)) + return Expr(:block, calls...) +end + +Base.keys(m::Parallel) = Base.keys(getfield(m, :layers)) + +""" + BranchLayer(layers...) + BranchLayer(; layers...) 
+ +Takes an input `x` and passes it through all the `layers` and returns a tuple of the +outputs. + +## Arguments + + - Layers can be specified in two formats: + + + A list of `N` Lux layers + + Specified as `N` keyword arguments. + +## Inputs + + - `x`: Will be directly passed to each of the `layers` + +## Returns + + - Tuple: `(layer_1(x), layer_2(x), ..., layer_N(x))` (naming changes if using the kwargs + API) + - Updated state of the `layers` + +## Parameters + + - Parameters of each `layer` wrapped in a NamedTuple with + `fields = layer_1, layer_2, ..., layer_N` (naming changes if using the kwargs API) + +## States + + - States of each `layer` wrapped in a NamedTuple with + `fields = layer_1, layer_2, ..., layer_N` (naming changes if using the kwargs API) + +!!! note "Comparison with Parallel" + + This is slightly different from [`Parallel(nothing, layers...)`](@ref) + + - If the input is a tuple, `Parallel` will pass each element individually to each + layer. + + - `BranchLayer` essentially assumes 1 input comes in and is branched out into `N` + outputs. + +## Example + +An easy way to replicate an input to an NTuple is to do + +```julia +l = BranchLayer(NoOpLayer(), NoOpLayer(), NoOpLayer()) +``` +""" +struct BranchLayer{T <: NamedTuple} <: AbstractExplicitContainerLayer{(:layers,)} + layers::T +end + +function BranchLayer(layers...) + names = ntuple(i -> Symbol("layer_$i"), length(layers)) + return BranchLayer(NamedTuple{names}(layers)) +end + +BranchLayer(; kwargs...) = BranchLayer((; kwargs...)) + +function (m::BranchLayer)(x, ps, st::NamedTuple) + return applybranching(m.layers, x, ps, st) +end + +@generated function applybranching(layers::NamedTuple{names}, x, ps, + st::NamedTuple) where {names} + N = length(names) + y_symbols = [gensym() for _ in 1:N] + st_symbols = [gensym() for _ in 1:N] + calls = [] + append!(calls, + [:(($(y_symbols[i]), $(st_symbols[i])) = layers.$(names[i])(x, ps.$(names[i]), + st.$(names[i]))) + for i in 1:N]) + push!(calls, :(st = NamedTuple{$names}((($(Tuple(st_symbols)...),))))) + push!(calls, :(return tuple($(Tuple(y_symbols)...)), st)) + return Expr(:block, calls...) +end + +Base.keys(m::BranchLayer) = Base.keys(getfield(m, :layers)) + +""" + PairwiseFusion(connection, layers...) + PairwiseFusion(connection; layers...) + +``` +x1 → layer1 → y1 ↘ + connection → layer2 → y2 ↘ + x2 ↗ connection → y3 + x3 ↗ +``` + +## Arguments + + - `connection`: Takes 2 inputs and combines them + + - `layers`: [`AbstractExplicitLayer`](@ref)s. Layers can be specified in two formats: + + + A list of `N` Lux layers + + Specified as `N` keyword arguments. + +## Inputs + +Layer behaves differently based on input type: + + 1. If the input `x` is a tuple of length `N + 1`, then the `layers` must be a tuple of + length `N`. The computation is as follows + +```julia +y = x[1] +for i in 1:N + y = connection(x[i + 1], layers[i](y)) +end +``` + + 2. 
Any other kind of input + +```julia +y = x +for i in 1:N + y = connection(x, layers[i](y)) +end +``` + +## Returns + + - See Inputs section for how the return value is computed + - Updated model state for all the contained layers + +## Parameters + + - Parameters of each `layer` wrapped in a NamedTuple with + `fields = layer_1, layer_2, ..., layer_N` (naming changes if using the kwargs API) + +## States + + - States of each `layer` wrapped in a NamedTuple with + `fields = layer_1, layer_2, ..., layer_N` (naming changes if using the kwargs API) +""" +struct PairwiseFusion{F, T <: NamedTuple} <: AbstractExplicitContainerLayer{(:layers,)} + connection::F + layers::T +end + +function PairwiseFusion(connection, layers...) + names = ntuple(i -> Symbol("layer_$i"), length(layers)) + return PairwiseFusion(connection, NamedTuple{names}(layers)) +end + +PairwiseFusion(connection; kwargs...) = PairwiseFusion(connection, (; kwargs...)) + +function (m::PairwiseFusion)(x, ps, st::NamedTuple) + return applypairwisefusion(m.layers, m.connection, x, ps, st) +end + +@generated function applypairwisefusion(layers::NamedTuple{names}, connection::C, x::T, ps, + st::NamedTuple) where {names, C, T} + N = length(names) + y_symbols = [gensym() for _ in 1:(N + 1)] + st_symbols = [gensym() for _ in 1:N] + getinput(i) = T <: Tuple ? :(x[$i]) : :x + calls = [:($(y_symbols[N + 1]) = $(getinput(1)))] + append!(calls, + [:(($(y_symbols[i]), $(st_symbols[i])) = layers.$(names[i])($(y_symbols[N + 1]), + ps.$(names[i]), + st.$(names[i])); + $(y_symbols[N + 1]) = connection($(y_symbols[i]), $(getinput(i + 1)))) + for i in 1:N]) + push!(calls, :(st = NamedTuple{$names}((($(Tuple(st_symbols)...),))))) + push!(calls, :(return $(y_symbols[N + 1]), st)) + return Expr(:block, calls...) +end + +Base.keys(m::PairwiseFusion) = Base.keys(getfield(m, :layers)) + +""" + Chain(layers...; disable_optimizations::Bool = false) + Chain(; layers..., disable_optimizations::Bool = false) + +Collects multiple layers / functions to be called in sequence on a given input. + +## Arguments + + - Layers can be specified in two formats: + + + A list of `N` Lux layers + + Specified as `N` keyword arguments. + +## Keyword Arguments + + - `disable_optimizations`: Prevents any structural optimization + +## Inputs + +Input `x` is passed sequentially to each layer, and must conform to the input requirements +of the internal layers. + +## Returns + + - Output after sequentially applying all the layers to `x` + - Updated model states + +## Parameters + + - Parameters of each `layer` wrapped in a NamedTuple with + `fields = layer_1, layer_2, ..., layer_N` (naming changes if using the kwargs API) + +## States + + - States of each `layer` wrapped in a NamedTuple with + `fields = layer_1, layer_2, ..., layer_N` (naming changes if using the kwargs API) + +## Optimizations + +Performs a few optimizations to generate reasonable architectures. Can be disabled using +keyword argument `disable_optimizations`. + + - All sublayers are recursively optimized. + - If a function `f` is passed as a layer and it doesn't take 3 inputs, it is converted to + a [`WrappedFunction`](@ref)(`f`) which takes only one input. + - If the layer is a Chain, it is flattened. + - [`NoOpLayer`](@ref)s are removed. + - If there is only 1 layer (left after optimizations), then it is returned without the + `Chain` wrapper. + - If there are no layers (left after optimizations), a [`NoOpLayer`](@ref) is returned. 
+ +## Example + +```julia +c = Chain(Dense(2, 3, relu), BatchNorm(3), Dense(3, 2)) +``` +""" +struct Chain{T} <: AbstractExplicitContainerLayer{(:layers,)} + layers::T + + function Chain(xs...; disable_optimizations::Bool=false) + xs = disable_optimizations ? xs : _flatten_model(xs) + length(xs) == 0 && return NoOpLayer() + length(xs) == 1 && return first(xs) + names = ntuple(i -> Symbol("layer_$i"), length(xs)) + layers = NamedTuple{names}(xs) + return new{typeof(layers)}(layers) + end + + function Chain(xs::AbstractVector; disable_optimizations::Bool=false) + return Chain(xs...; disable_optimizations) + end + + function Chain(nt::NamedTuple; disable_optimizations::Bool=true) + if !disable_optimizations + throw(ArgumentError("Chain(::NamedTuple) is not compatible with" * + " disable_optimizations=true")) + end + return new{typeof(nt)}(nt) + end + + function Chain(; disable_optimizations::Bool=true, kwargs...) + return Chain((; kwargs...); disable_optimizations) + end +end + +function _flatten_model(layers::Union{AbstractVector, Tuple}) + new_layers = [] + for l in layers + f = _flatten_model(l) + if f isa Tuple || f isa AbstractVector + append!(new_layers, f) + elseif f isa Function + if !hasmethod(f, (Any, Union{ComponentArray, NamedTuple}, NamedTuple)) + if f === identity + continue + else + push!(new_layers, WrappedFunction(f)) + end + else + push!(new_layers, f) + end + elseif f isa Chain + append!(new_layers, f.layers) + elseif f isa NoOpLayer + continue + else + push!(new_layers, f) + end + end + return layers isa AbstractVector ? new_layers : Tuple(new_layers) +end + +_flatten_model(x) = x + +function (c::Chain)(x, ps, st::NamedTuple) + return applychain(c.layers, x, ps, st) +end + +@generated function applychain(layers::NamedTuple{fields}, x, ps, + st::NamedTuple{fields}) where {fields} + N = length(fields) + x_symbols = vcat([:x], [gensym() for _ in 1:N]) + st_symbols = [gensym() for _ in 1:N] + calls = [:(($(x_symbols[i + 1]), $(st_symbols[i])) = layers.$(fields[i])($(x_symbols[i]), + ps.$(fields[i]), + st.$(fields[i]))) + for i in 1:N] + push!(calls, :(st = NamedTuple{$fields}((($(Tuple(st_symbols)...),))))) + push!(calls, :(return $(x_symbols[N + 1]), st)) + return Expr(:block, calls...) 
+end + +Base.keys(m::Chain) = Base.keys(getfield(m, :layers)) diff --git a/test/layers/basic.jl b/test/layers/basic.jl index aaf7acf0e9..a2828dbab1 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -81,131 +81,6 @@ Random.seed!(rng, 0) end end -@testset "Containers" begin - @testset "SkipConnection" begin - @testset "zero sum" begin - layer = SkipConnection(WrappedFunction(zero), (a, b) -> a .+ b) - println(layer) - ps, st = Lux.setup(rng, layer) - x = randn(rng, 10, 10, 10, 10) - - @test layer(x, ps, st)[1] == x - run_JET_tests(layer, x, ps, st) - test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol=1.0f-3, - rtol=1.0f-3) - end - - @testset "concat size" begin - layer = SkipConnection(Dense(10, 10), (a, b) -> hcat(a, b)) - println(layer) - ps, st = Lux.setup(rng, layer) - x = randn(rng, 10, 2) - - @test size(layer(x, ps, st)[1]) == (10, 4) - run_JET_tests(layer, x, ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; - atol=1.0f-3, rtol=1.0f-3) - end - end - - @testset "Parallel" begin - @testset "zero sum" begin - layer = Parallel(+, WrappedFunction(zero), NoOpLayer()) - println(layer) - ps, st = Lux.setup(rng, layer) - x = randn(rng, 10, 10, 10, 10) - - @test layer(x, ps, st)[1] == x - run_JET_tests(layer, x, ps, st) - test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol=1.0f-3, - rtol=1.0f-3) - end - - @testset "concat size" begin - layer = Parallel((a, b) -> cat(a, b; dims=2), Dense(10, 10), NoOpLayer()) - println(layer) - ps, st = Lux.setup(rng, layer) - x = randn(rng, 10, 2) - - @test size(layer(x, ps, st)[1]) == (10, 4) - run_JET_tests(layer, x, ps, st) - test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol=1.0f-3, - rtol=1.0f-3) - - layer = Parallel(hcat, Dense(10, 10), NoOpLayer()) - println(layer) - ps, st = Lux.setup(rng, layer) - - @test size(layer(x, ps, st)[1]) == (10, 4) - run_JET_tests(layer, x, ps, st) - test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol=1.0f-3, - rtol=1.0f-3) - end - - @testset "vararg input" begin - layer = Parallel(+, Dense(10, 2), Dense(5, 2), Dense(4, 2)) - println(layer) - ps, st = Lux.setup(rng, layer) - x = (randn(rng, 10, 1), randn(rng, 5, 1), randn(rng, 4, 1)) - - @test size(layer(x, ps, st)[1]) == (2, 1) - run_JET_tests(layer, x, ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; - atol=1.0f-3, rtol=1.0f-3) - end - - @testset "connection is called once" begin - CNT = Ref(0) - f_cnt = (x...) 
-> (CNT[] += 1; +(x...)) - layer = Parallel(f_cnt, WrappedFunction(sin), WrappedFunction(cos), - WrappedFunction(tan)) - ps, st = Lux.setup(rng, layer) - Lux.apply(layer, 1, ps, st) - @test CNT[] == 1 - run_JET_tests(layer, 1, ps, st) - Lux.apply(layer, (1, 2, 3), ps, st) - @test CNT[] == 2 - layer = Parallel(f_cnt, WrappedFunction(sin)) - Lux.apply(layer, 1, ps, st) - @test CNT[] == 3 - end - - # Ref https://github.com/FluxML/Flux.jl/issues/1673 - @testset "Input domain" begin - struct Input - x::Any - end - - struct L1 <: Lux.AbstractExplicitLayer end - (::L1)(x, ps, st) = (ps.x * x, st) - Lux.initialparameters(rng::AbstractRNG, ::L1) = (x=randn(rng, Float32, 3, 3),) - Base.:*(a::AbstractArray, b::Input) = a * b.x - - par = Parallel(+, L1(), L1()) - ps, st = Lux.setup(rng, par) - - ip = Input(rand(Float32, 3, 3)) - ip2 = Input(rand(Float32, 3, 3)) - - @test par(ip, ps, st)[1] ≈ - par.layers[1](ip.x, ps.layer_1, st.layer_1)[1] + - par.layers[2](ip.x, ps.layer_2, st.layer_2)[1] - @test par((ip, ip2), ps, st)[1] ≈ - par.layers[1](ip.x, ps.layer_1, st.layer_1)[1] + - par.layers[2](ip2.x, ps.layer_2, st.layer_2)[1] - gs = gradient((p, x...) -> sum(par(x, p, st)[1]), ps, ip, ip2) - gs_reg = gradient(ps, ip, ip2) do p, x, y - return sum(par.layers[1](x.x, p.layer_1, st.layer_1)[1] + - par.layers[2](y.x, p.layer_2, st.layer_2)[1]) - end - - @test gs[1] ≈ gs_reg[1] - @test gs[2].x ≈ gs_reg[2].x - @test gs[3].x ≈ gs_reg[3].x - end - end -end - @testset "Dense" begin @testset "constructors" begin layer = Dense(10, 100) diff --git a/test/layers/containers.jl b/test/layers/containers.jl new file mode 100644 index 0000000000..1845aed98d --- /dev/null +++ b/test/layers/containers.jl @@ -0,0 +1,238 @@ +using Lux, NNlib, Random, Test + +include("../test_utils.jl") + +rng = Random.default_rng() +Random.seed!(rng, 0) + +@testset "SkipConnection" begin + @testset "zero sum" begin + layer = SkipConnection(WrappedFunction(zero), (a, b) -> a .+ b) + println(layer) + ps, st = Lux.setup(rng, layer) + x = randn(rng, 10, 10, 10, 10) + + @test layer(x, ps, st)[1] == x + run_JET_tests(layer, x, ps, st) + test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol=1.0f-3, + rtol=1.0f-3) + end + + @testset "concat size" begin + layer = SkipConnection(Dense(10, 10), (a, b) -> hcat(a, b)) + println(layer) + ps, st = Lux.setup(rng, layer) + x = randn(rng, 10, 2) + + @test size(layer(x, ps, st)[1]) == (10, 4) + run_JET_tests(layer, x, ps, st) + test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; + atol=1.0f-3, rtol=1.0f-3) + end +end + +@testset "Parallel" begin + @testset "zero sum" begin + layer = Parallel(+, WrappedFunction(zero), NoOpLayer()) + println(layer) + ps, st = Lux.setup(rng, layer) + x = randn(rng, 10, 10, 10, 10) + + @test layer(x, ps, st)[1] == x + run_JET_tests(layer, x, ps, st) + test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol=1.0f-3, + rtol=1.0f-3) + end + + @testset "concat size" begin + layer = Parallel((a, b) -> cat(a, b; dims=2), Dense(10, 10), NoOpLayer()) + println(layer) + ps, st = Lux.setup(rng, layer) + x = randn(rng, 10, 2) + + @test size(layer(x, ps, st)[1]) == (10, 4) + run_JET_tests(layer, x, ps, st) + test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol=1.0f-3, + rtol=1.0f-3) + + layer = Parallel(hcat, Dense(10, 10), NoOpLayer()) + println(layer) + ps, st = Lux.setup(rng, layer) + + @test size(layer(x, ps, st)[1]) == (10, 4) + run_JET_tests(layer, x, ps, st) + test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; 
atol=1.0f-3, + rtol=1.0f-3) + end + + @testset "vararg input" begin + layer = Parallel(+, Dense(10, 2), Dense(5, 2), Dense(4, 2)) + println(layer) + ps, st = Lux.setup(rng, layer) + x = (randn(rng, 10, 1), randn(rng, 5, 1), randn(rng, 4, 1)) + + @test size(layer(x, ps, st)[1]) == (2, 1) + run_JET_tests(layer, x, ps, st) + test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; + atol=1.0f-3, rtol=1.0f-3) + end + + @testset "named layers" begin + layer = Parallel(+; d102=Dense(10, 2), d52=Dense(5, 2), d42=Dense(4, 2)) + println(layer) + ps, st = Lux.setup(rng, layer) + x = (randn(rng, 10, 1), randn(rng, 5, 1), randn(rng, 4, 1)) + + @test size(layer(x, ps, st)[1]) == (2, 1) + run_JET_tests(layer, x, ps, st) + test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; + atol=1.0f-3, rtol=1.0f-3) + end + + @testset "connection is called once" begin + CNT = Ref(0) + f_cnt = (x...) -> (CNT[] += 1; +(x...)) + layer = Parallel(f_cnt, WrappedFunction(sin), WrappedFunction(cos), + WrappedFunction(tan)) + ps, st = Lux.setup(rng, layer) + Lux.apply(layer, 1, ps, st) + @test CNT[] == 1 + run_JET_tests(layer, 1, ps, st) + Lux.apply(layer, (1, 2, 3), ps, st) + @test CNT[] == 2 + layer = Parallel(f_cnt, WrappedFunction(sin)) + Lux.apply(layer, 1, ps, st) + @test CNT[] == 3 + end + + # Ref https://github.com/FluxML/Flux.jl/issues/1673 + @testset "Input domain" begin + struct Input + x::Any + end + + struct L1 <: Lux.AbstractExplicitLayer end + (::L1)(x, ps, st) = (ps.x * x, st) + Lux.initialparameters(rng::AbstractRNG, ::L1) = (x=randn(rng, Float32, 3, 3),) + Base.:*(a::AbstractArray, b::Input) = a * b.x + + par = Parallel(+, L1(), L1()) + ps, st = Lux.setup(rng, par) + + ip = Input(rand(Float32, 3, 3)) + ip2 = Input(rand(Float32, 3, 3)) + + @test par(ip, ps, st)[1] ≈ + par.layers[1](ip.x, ps.layer_1, st.layer_1)[1] + + par.layers[2](ip.x, ps.layer_2, st.layer_2)[1] + @test par((ip, ip2), ps, st)[1] ≈ + par.layers[1](ip.x, ps.layer_1, st.layer_1)[1] + + par.layers[2](ip2.x, ps.layer_2, st.layer_2)[1] + gs = gradient((p, x...) 
-> sum(par(x, p, st)[1]), ps, ip, ip2) + gs_reg = gradient(ps, ip, ip2) do p, x, y + return sum(par.layers[1](x.x, p.layer_1, st.layer_1)[1] + + par.layers[2](y.x, p.layer_2, st.layer_2)[1]) + end + + @test gs[1] ≈ gs_reg[1] + @test gs[2].x ≈ gs_reg[2].x + @test gs[3].x ≈ gs_reg[3].x + end +end + +@testset "PairwiseFusion" begin + x = (rand(Float32, 1, 10), rand(Float32, 30, 10), rand(Float32, 10, 10)) + layer = PairwiseFusion(+, Dense(1, 30), Dense(30, 10)) + println(layer) + ps, st = Lux.setup(rng, layer) + y, _ = layer(x, ps, st) + @test size(y) == (10, 10) + run_JET_tests(layer, x, ps, st) + test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; atol=1.0f-3, + rtol=1.0f-3) + + layer = PairwiseFusion(+; d1=Dense(1, 30), d2=Dense(30, 10)) + println(layer) + ps, st = Lux.setup(rng, layer) + y, _ = layer(x, ps, st) + @test size(y) == (10, 10) + run_JET_tests(layer, x, ps, st) + test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; atol=1.0f-3, + rtol=1.0f-3) + + x = rand(1, 10) + layer = PairwiseFusion(.+, Dense(1, 10), Dense(10, 1)) + println(layer) + ps, st = Lux.setup(rng, layer) + y, _ = layer(x, ps, st) + @test size(y) == (1, 10) + run_JET_tests(layer, x, ps, st) + test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; atol=1.0f-3, + rtol=1.0f-3) + + layer = PairwiseFusion(vcat, WrappedFunction(x -> x .+ 1), WrappedFunction(x -> x .+ 2), + WrappedFunction(x -> x .^ 3)) + println(layer) + ps, st = Lux.setup(rng, layer) + @test layer((2, 10, 20, 40), ps, st)[1] == [125, 1728, 8000, 40] + + layer = PairwiseFusion(vcat, WrappedFunction(x -> x .+ 1), WrappedFunction(x -> x .+ 2), + WrappedFunction(x -> x .^ 3)) + println(layer) + ps, st = Lux.setup(rng, layer) + @test layer(7, ps, st)[1] == [1000, 729, 343, 7] +end + +@testset "BranchLayer" begin + layer = BranchLayer(Dense(10, 10), Dense(10, 10)) + println(layer) + ps, st = Lux.setup(rng, layer) + x = rand(Float32, 10, 1) + (y1, y2), _ = layer(x, ps, st) + @test size(y1) == (10, 1) + @test size(y2) == (10, 1) + @test y1 == layer.layers.layer_1(x, ps.layer_1, st.layer_1)[1] + @test y2 == layer.layers.layer_2(x, ps.layer_2, st.layer_2)[1] + run_JET_tests(layer, x, ps, st) + test_gradient_correctness_fdm((x, ps) -> sum(sum, layer(x, ps, st)[1]), x, ps; + atol=1.0f-3, rtol=1.0f-3) + + layer = BranchLayer(; d1=Dense(10, 10), d2=Dense(10, 10)) + println(layer) + ps, st = Lux.setup(rng, layer) + x = rand(Float32, 10, 1) + (y1, y2), _ = layer(x, ps, st) + @test size(y1) == (10, 1) + @test size(y2) == (10, 1) + @test y1 == layer.layers.d1(x, ps.d1, st.d1)[1] + @test y2 == layer.layers.d2(x, ps.d2, st.d2)[1] + run_JET_tests(layer, x, ps, st) + test_gradient_correctness_fdm((x, ps) -> sum(sum, layer(x, ps, st)[1]), x, ps; + atol=1.0f-3, rtol=1.0f-3) +end + +@testset "Chain" begin + layer = Chain(Dense(10 => 5, sigmoid), Dense(5 => 2, tanh), Dense(2 => 1)) + println(layer) + ps, st = Lux.setup(rng, layer) + x = rand(Float32, 10, 1) + y, _ = layer(x, ps, st) + @test size(y) == (1, 1) + run_JET_tests(layer, x, ps, st) + test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; atol=1.0f-3, + rtol=1.0f-3) + + layer = Chain(; l1=Dense(10 => 5, sigmoid), d52=Dense(5 => 2, tanh), d21=Dense(2 => 1)) + println(layer) + ps, st = Lux.setup(rng, layer) + x = rand(Float32, 10, 1) + y, _ = layer(x, ps, st) + @test size(y) == (1, 1) + run_JET_tests(layer, x, ps, st) + test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; atol=1.0f-3, + rtol=1.0f-3) + + @test_throws 
ArgumentError Chain(; l1=Dense(10 => 5, sigmoid), d52=Dense(5 => 2, tanh), + d21=Dense(2 => 1), d2=Dense(2 => 1), + disable_optimizations=false) +end diff --git a/test/runtests.jl b/test/runtests.jl index f405e968e0..d1d04eedfe 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -32,6 +32,7 @@ end @testset "Layers" begin @time @safetestset "Basic" begin include("layers/basic.jl") end + @time @safetestset "Containers" begin include("layers/containers.jl") end @time @safetestset "Convolution" begin include("layers/conv.jl") end @time @safetestset "Normalization" begin include("layers/normalize.jl") end @time @safetestset "Recurrent" begin include("layers/recurrent.jl") end diff --git a/test/test_utils.jl b/test/test_utils.jl index 34f0284e5f..56ffb2d85c 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -5,6 +5,10 @@ function Base.isapprox(x, y; kwargs...) return x == y end +function Base.isapprox(x::Tuple, y::Tuple; kwargs...) + return all(isapprox.(x, y; kwargs...)) +end + function Base.isapprox(x::Optimisers.Leaf, y::Optimisers.Leaf; kwargs...) return isapprox(x.rule, y.rule; kwargs...) && isapprox(x.state, y.state; kwargs...) end
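
---

Below is a minimal usage sketch (not part of the patch itself) illustrating the keyword-based layer naming that this PR adds to the container layers (`Chain`, `Parallel`, `BranchLayer`, `PairwiseFusion`). The constructor forms shown are taken from the docstrings and tests above; the layer names `encoder`, `head`, `path_a`, and `path_b` are arbitrary illustrative choices, and the sketch assumes Lux v0.4.19 as set in `Project.toml`.

```julia
using Lux, Random

# Name the sublayers via keyword arguments (added in v0.4.19). The parameter and
# state NamedTuples then use these names instead of layer_1, layer_2, ...
model = Chain(; encoder=Dense(10 => 5, tanh), head=Dense(5 => 2))

rng = Random.default_rng()
Random.seed!(rng, 0)
ps, st = Lux.setup(rng, model)

x = rand(Float32, 10, 3)
y, st_new = model(x, ps, st)   # size(y) == (2, 3)

ps.encoder.weight              # parameters are accessed by the custom names

# The same keyword API applies to the other containers, e.g. Parallel:
branch = Parallel(+; path_a=Dense(4 => 2), path_b=Dense(4 => 2))
```

Note that, per the `Chain(::NamedTuple)` constructor in this patch, the keyword form keeps `disable_optimizations=true`, so structural optimizations (flattening, `NoOpLayer` removal) are only applied when layers are passed positionally.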