Merge pull request #287 from avik-pal/ap/return_seq
Return the history for Recurrence
avik-pal authored Mar 15, 2023
2 parents a19a11c + ed64a78 commit 8a82e51
Showing 7 changed files with 83 additions and 58 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "Lux"
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
version = "0.4.43"
version = "0.4.44"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
1 change: 0 additions & 1 deletion src/adapt.jl
@@ -12,7 +12,6 @@ function adapt_storage(::LuxCPUAdaptor,
end
adapt_storage(::LuxCPUAdaptor, x::AbstractArray) = adapt(Array, x)
adapt_storage(::LuxCPUAdaptor, rng::AbstractRNG) = rng
-# TODO(@avik-pal): SparseArrays
function adapt_storage(::LuxCPUAdaptor, x::CUDA.CUSPARSE.AbstractCuSparseMatrix)
    return adapt(Array, x)
end
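As an aside, a minimal sketch of where these `adapt_storage` methods are exercised. This is illustrative only: it assumes a working CUDA device and uses `Lux.cpu`, the public helper that dispatches through `LuxCPUAdaptor`; the variable names are hypothetical.

using Lux, CUDA

# Moving device arrays back to host memory funnels through
# adapt_storage(::LuxCPUAdaptor, ...), including the CuSparse method above.
x_gpu = CUDA.cu(randn(Float32, 4, 4))
x_cpu = Lux.cpu(x_gpu)
@assert x_cpu isa Array{Float32, 2}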
10 changes: 10 additions & 0 deletions src/chainrules.jl
@@ -47,6 +47,10 @@ function CRC.rrule(::typeof(copy), x)
    return copy(x), copy_pullback
end

+function CRC.rrule(::typeof(_eachslice), x, d)
+    return _eachslice(x, d), Δ -> (NoTangent(), ∇_eachslice(Δ, x, d), NoTangent())
+end

# Adapt Interface
function CRC.rrule(::Type{Array}, x::CUDA.CuArray)
    return Array(x), d -> (NoTangent(), CUDA.cu(d))
@@ -75,3 +79,9 @@ function CRC.rrule(::typeof(multigate), x::AbstractArray, c::Val{N}) where {N}
    end
    return multigate(x, c), multigate_pullback
end

+# layers/recurrent.jl
+function CRC.rrule(::typeof(_generate_init_recurrence), out, carry, state)
+    result = _generate_init_recurrence(out, carry, state)
+    return result, Δ -> (NoTangent(), ∇_generate_init_recurrence(Δ)...)
+end
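To make the last rule's contract concrete, here is a hedged sketch that calls it directly. Note `_generate_init_recurrence` and its pullback are internal helpers introduced by this PR, not public API, and the shapes are made up for illustration.

import Lux
using ChainRulesCore

# Forward: wrap the first step's output in a one-element Vector so that later
# steps can `vcat` their outputs onto it.
out, carry, st = ones(Float32, 5, 2), zeros(Float32, 5, 2), NamedTuple()
res, pb = ChainRulesCore.rrule(Lux._generate_init_recurrence, out, carry, st)
@assert res == (typeof(out)[out], carry, st)

# Reverse: the pullback unwraps the accumulator's single cotangent again.
Δ = pb(([2 .* ones(Float32, 5, 2)], NoTangent(), NoTangent()))
@assert Δ[2] == 2 .* ones(Float32, 5, 2)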
50 changes: 34 additions & 16 deletions src/layers/recurrent.jl
@@ -1,7 +1,7 @@
# TODO(@avik-pal): We can add another subtype `AbstractRecurrentCell` in the type hierarchy
# to make it safer to compose these cells with `Recurrence`
"""
-    Recurrence(cell)
+    Recurrence(cell; return_sequence::Bool = false)

Wraps a recurrent cell (like [`RNNCell`](@ref), [`LSTMCell`](@ref), [`GRUCell`](@ref)) to
automatically operate over a sequence of inputs.
@@ -17,6 +17,11 @@ automatically operate over a sequence of inputs.
- `cell`: A recurrent cell. See [`RNNCell`](@ref), [`LSTMCell`](@ref), [`GRUCell`](@ref),
for how the inputs/outputs of a recurrent cell must be structured.
+## Keyword Arguments
+
+  - `return_sequence`: If `true`, returns the entire sequence of outputs, else returns only
+    the last output. Defaults to `false`.
## Inputs
- If `x` is a
@@ -39,31 +44,42 @@ automatically operate over a sequence of inputs.
- Same as `cell`.
"""
-struct Recurrence{C <: AbstractExplicitLayer} <: AbstractExplicitContainerLayer{(:cell,)}
+struct Recurrence{R, C <: AbstractExplicitLayer} <: AbstractExplicitContainerLayer{(:cell,)}
    cell::C
end

+function Recurrence(cell; return_sequence::Bool=false)
+    return Recurrence{return_sequence, typeof(cell)}(cell)
+end

@inline function (r::Recurrence)(x::A, ps, st::NamedTuple) where {A <: AbstractArray}
    return Lux.apply(r, _eachslice(x, Val(ndims(x) - 1)), ps, st)
end

# Non-ideal dispatch since we can't unroll the loop easily while the code is exactly the
# same as for NTuple
-function (r::Recurrence)(x::AbstractVector, ps, st::NamedTuple)
-    (out, carry), st = r.cell(first(x), ps, st)
-    for x_ in x[2:end]
-        (out, carry), st = r.cell((x_, carry), ps, st)
+function (r::Recurrence{false})(x::Union{AbstractVector, NTuple}, ps, st::NamedTuple)
+    (out, carry), st = Lux.apply(r.cell, first(x), ps, st)
+    for x_ in x[(begin + 1):end]
+        (out, carry), st = Lux.apply(r.cell, (x_, carry), ps, st)
    end
    return out, st
end

-@generated function (r::Recurrence)(x::NTuple{N}, ps, st::NamedTuple) where {N}
-    quote
-        (out, carry), st = r.cell(x[1], ps, st)
-        Base.Cartesian.@nexprs $(N - 1) i->((out, carry), st) = r.cell((x[i + 1], carry),
-                                                                       ps, st)
-        return out, st
-    end
-end
+# FIXME: Weird Hack
+_generate_init_recurrence(out, carry, st) = (typeof(out)[out], carry, st)
+∇_generate_init_recurrence((Δout, Δcarry, Δst)) = (first(Δout), Δcarry, Δst)

+function (r::Recurrence{true})(x::Union{AbstractVector, NTuple}, ps, st::NamedTuple)
+    (out_, carry), st = Lux.apply(r.cell, first(x), ps, st)
+
+    init = _generate_init_recurrence(out_, carry, st)
+
+    # `foldl` threads the carry left to right through the remaining inputs,
+    # appending each step's output to the accumulated history.
+    function recurrence_op((outputs, carry, state), input)
+        (out, carry), state = Lux.apply(r.cell, (input, carry), ps, state)
+        return vcat(outputs, typeof(out)[out]), carry, state
+    end
+
+    results = foldl(recurrence_op, x[(begin + 1):end]; init)
+    return first(results), last(results)
+end

"""
@@ -118,8 +134,10 @@ function (r::StatefulRecurrentCell)(x, ps, st::NamedTuple)
    return out, (; cell=st_, carry)
end

-applyrecurrentcell(l::AbstractExplicitLayer, x, ps, st, carry) = l((x, carry), ps, st)
-applyrecurrentcell(l::AbstractExplicitLayer, x, ps, st, ::Nothing) = l(x, ps, st)
+function applyrecurrentcell(l::AbstractExplicitLayer, x, ps, st, carry)
+    return Lux.apply(l, (x, carry), ps, st)
+end
+applyrecurrentcell(l::AbstractExplicitLayer, x, ps, st, ::Nothing) = Lux.apply(l, x, ps, st)

@doc doc"""
RNNCell(in_dims => out_dims, activation=tanh; bias::Bool=true,
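Stepping back from the diff, a usage sketch of the new keyword. Shapes mirror the updated tests further below; the variable names are illustrative rather than canonical.

using Lux, Random

rng = Random.default_rng()
cell = RNNCell(3 => 5)
rnn_last = Recurrence(cell)                       # only the final output
rnn_seq = Recurrence(cell; return_sequence=true)  # the whole history

x = randn(rng, Float32, 3, 4, 2)  # (features, sequence length, batch)
ps, st = Lux.setup(rng, rnn_last)

y_last, _ = rnn_last(x, ps, st)  # size (5, 2)
y_seq, _ = rnn_seq(x, ps, st)    # Vector of 4 outputs, each of size (5, 2)

Encoding `return_sequence` as a type parameter rather than a struct field lets the two behaviours dispatch to separate methods, so the default path keeps its simple loop while the sequence path pays for history accumulation (and the extra rrules above) only when asked.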
12 changes: 4 additions & 8 deletions src/stacktraces.jl
@@ -10,16 +10,12 @@ function disable_stacktrace_truncation!(; disable::Bool=true)
end

# NamedTuple -- Lux uses them quite frequently (states), making the error messages too verbose
-function Base.show(io::IO, ::Type{<:NamedTuple{fields, fTypes}}) where {fields, fTypes}
+function Base.show(io::IO, t::Type{<:NamedTuple{fields, fTypes}}) where {fields, fTypes}
    if TruncatedStacktraces.VERBOSE[]
-        print(io, "NamedTuple{$fields, $fTypes}")
+        invoke(show, Tuple{IO, Type}, io, t)
    else
-        fields_truncated = if length(fields) > 2
-            "($(fields[1]), $(fields[2]), ...)"
-        else
-            fields
-        end
-        print(io, "NamedTuple{$fields_truncated, ...}")
+        fields_truncated = length(fields) > 2 ? "($(fields[1]),$(fields[2]),…)" : fields
+        print(io, "NamedTuple{$fields_truncated,…}")
    end
end

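For illustration, roughly what the truncated printing buys (a sketch; the outputs shown in the comments are approximate, and `nt` is a made-up value):

using Lux  # loads the Base.show override above

nt = (weight=1.0f0, bias=0.0f0, extra=:cache)
show(stdout, typeof(nt))
# truncated (the default):  NamedTuple{(weight,bias,…),…}
# after Lux.disable_stacktrace_truncation!(), the full
# NamedTuple{(:weight, :bias, :extra), ...} form is printed via `invoke`.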
18 changes: 17 additions & 1 deletion src/utils.jl
@@ -217,10 +217,26 @@ end
end
end

-@inline function _eachslice(x::T, ::Val{dims}) where {T <: AbstractArray, dims}
+@inline function _eachslice(x::AbstractArray, ::Val{dims}) where {dims}
    return [selectdim(x, dims, i) for i in axes(x, dims)]
end

+function ∇_eachslice(Δ_raw, x::AbstractArray, ::Val{dims}) where {dims}
+    Δs = CRC.unthunk(Δ_raw)
+    i1 = findfirst(Δ -> Δ isa AbstractArray, Δs)
+    i1 === nothing && return zero.(x)  # all slices are Zero!
+    Δ = similar(x)
+    for i in axes(x, dims)
+        Δi = selectdim(Δ, dims, i)
+        if Δs[i] isa CRC.AbstractZero
+            fill!(Δi, 0)
+        else
+            copyto!(Δi, Δs[i])
+        end
+    end
+    return CRC.ProjectTo(x)(Δ)
+end

# Backend Integration
## Convolution
@inline _conv(x, weight, cdims) = conv(x, weight, cdims)
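An end-to-end check of the `_eachslice`/`∇_eachslice` pair above through Zygote (a sketch; `_eachslice` is internal, and the gradient of a plain sum should reconstruct to all ones):

using Lux, Zygote

x = randn(Float32, 3, 4, 2)
# Forward: a Vector of 2 views, one per slice along the last dimension.
slices = Lux._eachslice(x, Val(3))
@assert length(slices) == 2 && size(first(slices)) == (3, 4)

# Reverse: the rrule in src/chainrules.jl routes the slice cotangents through
# ∇_eachslice, reassembling an array with the same shape as x.
g = only(Zygote.gradient(x -> sum(sum, Lux._eachslice(x, Val(3))), x))
@assert g == ones(Float32, size(x))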
48 changes: 17 additions & 31 deletions test/layers/recurrent.jl
@@ -248,40 +248,26 @@ end end

cell = _cell(3 => 5; use_bias, train_state)
rnn = Recurrence(cell)
+rnn_seq = Recurrence(cell; return_sequence=true)
display(rnn)

-# Batched Time Series
-x = randn(rng, Float32, 3, 4, 2)
-ps, st = Lux.setup(rng, rnn)
-y, st_ = rnn(x, ps, st)
-
-run_JET_tests(rnn, x, ps, st)
-
-@test size(y) == (5, 2)
-
-test_gradient_correctness_fdm(p -> sum(rnn(x, p, st)[1]), ps; atol=1e-2, rtol=1e-2)
-
-# Tuple of Time Series
-x = Tuple(randn(rng, Float32, 3, 2) for _ in 1:4)
-ps, st = Lux.setup(rng, rnn)
-y, st_ = rnn(x, ps, st)
-
-run_JET_tests(rnn, x, ps, st)
-
-@test size(y) == (5, 2)
-
-test_gradient_correctness_fdm(p -> sum(rnn(x, p, st)[1]), ps; atol=1e-2, rtol=1e-2)
-
-# Vector of Time Series
-x = [randn(rng, Float32, 3, 2) for _ in 1:4]
-ps, st = Lux.setup(rng, rnn)
-y, st_ = rnn(x, ps, st)
-
-run_JET_tests(rnn, x, ps, st)
-
-@test size(y) == (5, 2)
-
-test_gradient_correctness_fdm(p -> sum(rnn(x, p, st)[1]), ps; atol=1e-2, rtol=1e-2)
+for x in (randn(rng, Float32, 3, 4, 2), Tuple(randn(rng, Float32, 3, 2) for _ in 1:4),
+          [randn(rng, Float32, 3, 2) for _ in 1:4])
+    ps, st = Lux.setup(rng, rnn)
+    y, st_ = rnn(x, ps, st)
+    y_, st__ = rnn_seq(x, ps, st)
+    run_JET_tests(rnn, x, ps, st)
+    run_JET_tests(rnn_seq, x, ps, st)
+
+    @test size(y) == (5, 2)
+    @test length(y_) == 4
+    @test all(x -> size(x) == (5, 2), y_)
+
+    test_gradient_correctness_fdm(p -> sum(rnn(x, p, st)[1]), ps; atol=1e-2, rtol=1e-2)
+    test_gradient_correctness_fdm(p -> sum(Base.Fix1(sum, abs2), rnn_seq(x, p, st)[1]),
+                                  ps; atol=1e-2, rtol=1e-2)
+end
end end

@testset "multigate" begin

2 comments on commit 8a82e51

@avik-pal (Member Author)

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/79662

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.4.44 -m "<description of version>" 8a82e51ca2125e23fa5ac9601fb67f8a162f13bd
git push origin v0.4.44
