Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add (naive) FFBS algo #20

Draft
wants to merge 13 commits into
base: fred/auxiliary-particle-filter
Choose a base branch
from
2 changes: 2 additions & 0 deletions src/GeneralisedFilters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,5 +128,7 @@ include("algorithms/apf.jl")
include("algorithms/kalman.jl")
include("algorithms/forward.jl")
include("algorithms/rbpf.jl")
include("algorithms/ffbs.jl")
include("algorithms/guidedpf.jl")

end
14 changes: 9 additions & 5 deletions src/algorithms/apf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ mutable struct AuxiliaryParticleFilter{N,RS<:AbstractConditionalResampler} <: Ab
end

function AuxiliaryParticleFilter(
N::Integer; threshold::Real=0., resampler::AbstractResampler=Systematic()
N::Integer; threshold::Real=0.0, resampler::AbstractResampler=Systematic()
)
conditional_resampler = ESSResampler(threshold, resampler)
return AuxiliaryParticleFilter{N,typeof(conditional_resampler)}(conditional_resampler, zeros(N))
Expand All @@ -24,7 +24,9 @@ function initialise(
initial_states = map(x -> SSMProblems.simulate(rng, model.dyn; kwargs...), 1:N)
initial_weights = zeros(T, N)

return update_ref!(ParticleContainer(initial_states, initial_weights), ref_state, filter)
return update_ref!(
ParticleContainer(initial_states, initial_weights), ref_state, filter
)
end

function update_weights!(
Expand Down Expand Up @@ -59,7 +61,9 @@ function predict(
states.filtered.log_weights .+= auxiliary_weights
filter.aux = auxiliary_weights

states.proposed, states.ancestors = resample(rng, filter.resampler, states.filtered, filter)
states.proposed, states.ancestors = resample(
rng, filter.resampler, states.filtered, filter
)
states.proposed.particles = map(
x -> SSMProblems.simulate(rng, model.dyn, step, x; kwargs...),
states.proposed.particles,
Expand All @@ -70,12 +74,12 @@ end

function update(
model::StateSpaceModel{T},
filter::AuxiliaryParticleFilter,
filter::AuxiliaryParticleFilter{N},
step::Integer,
states::ParticleContainer,
observation;
kwargs...,
) where {T}
) where {T,N}
@debug "step $step"
log_increments = map(
x -> SSMProblems.logdensity(model.obs, step, x, observation; kwargs...),
Expand Down
18 changes: 11 additions & 7 deletions src/algorithms/bootstrap.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
export BootstrapFilter, BF

abstract type AbstractParticleFilter{N} <: AbstractFilter end

struct BootstrapFilter{N,RS<:AbstractResampler} <: AbstractParticleFilter{N}
resampler::RS
end
Expand All @@ -8,7 +10,7 @@ function BootstrapFilter(
N::Integer; threshold::Real=1.0, resampler::AbstractResampler=Systematic()
)
conditional_resampler = ESSResampler(threshold, resampler)
return BootstrapFilter{N, typeof(conditional_resampler)}(conditional_resampler)
return BootstrapFilter{N,typeof(conditional_resampler)}(conditional_resampler)
end

"""Shorthand for `BootstrapFilter`"""
Expand Down Expand Up @@ -38,7 +40,9 @@ function predict(
ref_state::Union{Nothing,AbstractVector{T}}=nothing,
kwargs...,
) where {T}
states.proposed, states.ancestors = resample(rng, filter.resampler, states.filtered, filter)
states.proposed, states.ancestors = resample(
rng, filter.resampler, states.filtered, filter
)
states.proposed.particles = map(
x -> SSMProblems.simulate(rng, model.dyn, step, x; kwargs...),
collect(states.proposed),
Expand All @@ -49,12 +53,12 @@ end

function update(
model::StateSpaceModel{T},
filter::BootstrapFilter,
filter::BootstrapFilter{N},
step::Integer,
states::ParticleContainer,
observation;
kwargs...,
) where {T}
) where {T,N}
log_increments = map(
x -> SSMProblems.logdensity(model.obs, step, x, observation; kwargs...),
collect(states.proposed),
Expand All @@ -67,12 +71,12 @@ function update(
end

function reset_weights!(
state::ParticleState{T,WT}, idxs, filter::BootstrapFilter
state::ParticleState{T,WT}, idxs, ::BootstrapFilter
) where {T,WT<:Real}
fill!(state.log_weights, -log(WT(length(state.particles))))
fill!(state.log_weights, zero(WT))
return state
end

function logmarginal(states::ParticleContainer, ::BootstrapFilter)
return logsumexp(states.filtered.log_weights) - logsumexp(states.proposed.log_weights)
end
end
95 changes: 95 additions & 0 deletions src/algorithms/ffbs.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
export FFBS

"""
    AbstractSmoother

Abstract supertype for smoothing algorithms.
"""
abstract type AbstractSmoother <: AbstractSampler end

"""
    FFBS(filter)

Forward-filtering backward-sampling smoother wrapping a particle `filter`, which is run
forwards before trajectories are drawn backwards (see `sample`).
"""
struct FFBS{T<:AbstractParticleFilter} <: AbstractSmoother
    filter::T
end

"""
    smooth(rng::AbstractRNG, alg::AbstractSmoother, model::AbstractStateSpaceModel, obs::AbstractVector, M::Integer; callback, kwargs...)

Generic entry point for smoothing algorithms; concrete smoothers add methods to this
function. No default implementation is provided.
"""
function smooth end

"""
    WeightedParticleRecorderCallback(particles, log_weights)

Filtering callback that records, at every time step, the filtered particles and their
log-weights into the pre-allocated arrays `particles` and `log_weights` (one row per
step).
"""
struct WeightedParticleRecorderCallback{T,WT}
    particles::Array{T}
    log_weights::Array{WT}
end

function (recorder::WeightedParticleRecorderCallback)(
    model, filter, step, states, data; kwargs...
)
    state = states.filtered
    recorder.particles[step, :] = state.particles
    recorder.log_weights[step, :] = state.log_weights
    return nothing
end

"""
    gen_trajectory(rng, model, particles, log_weights, forward_state, n_timestep; kwargs...)

Draw one smoothed trajectory by backward sampling: starting from `forward_state` at the
final step, an ancestor index is drawn at each earlier step from the backward kernel
built out of the recorded filtering `particles` and `log_weights` (one row per step).
"""
function gen_trajectory(
    rng::Random.AbstractRNG,
    model::StateSpaceModel,
    particles::AbstractMatrix{T}, # Need better container
    log_weights::AbstractMatrix{WT},
    forward_state,
    n_timestep::Int;
    kwargs...,
) where {T,WT}
    trajectory = Vector{T}(undef, n_timestep)
    trajectory[n_timestep] = forward_state
    for step in (n_timestep - 1):-1:1
        # Unnormalised backward log-weights for choosing this step's ancestor.
        logw = backward(
            model,
            step,
            trajectory[step + 1],
            particles[step, :],
            log_weights[step, :];
            kwargs...,
        )
        idx = rand(rng, Categorical(softmax(logw)))
        trajectory[step] = particles[step, idx]
    end
    return trajectory
end

"""
    backward(model, step, state, particles, log_weights; kwargs...)

Return the unnormalised backward-sampling log-weights at `step`: each recorded
particle's filtering log-weight plus the transition log-density from that particle to
`state`.
"""
function backward(
    model::StateSpaceModel, step::Integer, state, particles::T, log_weights::WT; kwargs...
) where {T,WT}
    transition_logdensities = [
        SSMProblems.logdensity(model.dyn, step, prev_state, state; kwargs...) for
        prev_state in particles
    ]
    return log_weights + transition_logdensities
end

"""
    sample(rng, model, alg::FFBS, obs, M; callback=nothing, kwargs...)

Forward-filtering backward-sampling: run the wrapped bootstrap filter over `obs`,
recording the filtering particles and log-weights at every step, then draw `M`
trajectories backwards from the joint smoothing distribution.

Returns an `n_timestep × M` array of states, one column per sampled trajectory.
"""
function sample(
    rng::Random.AbstractRNG,
    model::StateSpaceModel{T,LDT},
    alg::FFBS{<:BootstrapFilter{N}},
    obs::AbstractVector,
    M::Integer;
    callback=nothing,
    kwargs...,
) where {T,LDT,N}
    n_timestep = length(obs)
    recorder = WeightedParticleRecorderCallback(
        Array{eltype(model.dyn)}(undef, n_timestep, N), Array{T}(undef, n_timestep, N)
    )

    # NOTE(review): the user-supplied `callback` keyword is currently ignored; the
    # recorder is installed in its place. Consider chaining the two callbacks.
    particles, _ = filter(rng, model, alg.filter, obs; callback=recorder, kwargs...)

    # Backward sampling - exact: draw each trajectory's terminal state from the final
    # filtering distribution.
    idx_ref = rand(rng, Categorical(weights(particles.filtered)), M)
    trajectories = Array{eltype(model.dyn)}(undef, n_timestep, M)

    trajectories[end, :] = particles.filtered[idx_ref]
    for j in 1:M
        trajectories[:, j] = gen_trajectory(
            rng,
            model,
            recorder.particles,
            recorder.log_weights,
            trajectories[end, j],
            n_timestep;
            # Forward kwargs so the backward transition densities see the same
            # keyword arguments as the forward pass.
            kwargs...,
        )
    end
    return trajectories
end
164 changes: 164 additions & 0 deletions src/algorithms/guidedpf.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
export GuidedFilter, GPF, AbstractProposal

## PROPOSALS ###############################################################################
"""
AbstractProposal
"""
abstract type AbstractProposal end

function SSMProblems.distribution(
model::AbstractStateSpaceModel,
prop::AbstractProposal,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we should change the SSMProblems interface. The proposal should be part of the filter interface, maybe something along the lines of:

abstract type AbstractProposal end

abstract type AbstractParticleFilter{N, P<:AbstractProposal} end 

struct ParticleFilter{N,RS,P} <: AbstractParticleFilter{N,P}
    resampler::RS
    proposal::P
end

# Default to latent dynamics
struct LatentProposal <: AbstractProposal end

const BootstrapFilter{N,RS} = ParticleFilter{N,RS,LatentProposal}
const BF = BootstrapFilter

function propose(
    rng::AbstractRNG, 
    prop::LatentProposal, 
    model::AbstractStateSpaceModel, 
    particles::ParticleContainer, 
    step, 
    state, 
    obs; 
    kwargs...
)
    return SSMProblems.simulate(rng, model.dyn, step, state; kwargs...)
end

function logdensity(prop::AbstractProposal, ...)
   return SSMProblems.logdensity(...)
end

And we should probably update the filter/predict functions:

function predict(
    rng::AbstractRNG,
    model::StateSpaceModel,
    filter::BootstrapFilter,
    step::Integer,
    states::ParticleContainer{T};
    ref_state::Union{Nothing,AbstractVector{T}}=nothing,
    kwargs...,
) where {T}
    states.proposed, states.ancestors = resample(
        rng, filter.resampler, states.filtered, filter
    )
    states.proposed.particles = map(states.proposed) do state
        propose(rng, filter.proposal, model.dyn, step, state; kwargs...)
    end

    return update_ref!(states, ref_state, filter, step)
end

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I 100% agree with the BF integration, I was intentionally working my way up to that, but didn't want to drastically change the interface upon the first commit.

And you're totally right about the SSMProblems integration. But it was convenient to recycle the structures.

step::Integer,
state,
observation;
kwargs...,
)
return throw(
MethodError(
SSMProblems.distribution, (model, prop, step, state, observation, kwargs...)
),
)
end

"""
    SSMProblems.simulate(rng, model, prop, step, state, observation; kwargs...)

Draw a proposed state by sampling from the proposal's `distribution`, conditioned on the
current `state` and `observation`.
"""
function SSMProblems.simulate(
    rng::AbstractRNG,
    model::AbstractStateSpaceModel,
    prop::AbstractProposal,
    step::Integer,
    state,
    observation;
    kwargs...,
)
    proposal_dist = SSMProblems.distribution(
        model, prop, step, state, observation; kwargs...
    )
    return rand(rng, proposal_dist)
end

"""
    SSMProblems.logdensity(model, prop, step, prev_state, new_state, observation; kwargs...)

Log-density of `new_state` under the proposal distribution conditioned on `prev_state`
and `observation`.
"""
function SSMProblems.logdensity(
    model::AbstractStateSpaceModel,
    prop::AbstractProposal,
    step::Integer,
    prev_state,
    new_state,
    observation;
    kwargs...,
)
    proposal_dist = SSMProblems.distribution(
        model, prop, step, prev_state, observation; kwargs...
    )
    return logpdf(proposal_dist, new_state)
end

## GUIDED FILTERING ########################################################################

"""
    GuidedFilter

Particle filter that proposes new states from a user-supplied `proposal` rather than the
latent dynamics, resampling with `resampler`. `N` is the number of particles.
"""
struct GuidedFilter{N,RS<:AbstractResampler,P<:AbstractProposal} <:
       AbstractParticleFilter{N}
    resampler::RS
    proposal::P
end

"""
    GuidedFilter(N, proposal; threshold=1.0, resampler=Systematic())

Construct a guided filter with `N` particles, wrapping `resampler` in an ESS-based
conditional resampler triggered at `threshold`.
"""
function GuidedFilter(
    N::Integer, proposal::P; threshold::Real=1.0, resampler::AbstractResampler=Systematic()
) where {P<:AbstractProposal}
    ess_resampler = ESSResampler(threshold, resampler)
    return GuidedFilter{N,typeof(ess_resampler),P}(ess_resampler, proposal)
end

"""Shorthand alias for [`GuidedFilter`](@ref)."""
const GPF = GuidedFilter

"""
    initialise(rng, model, filter::GuidedFilter; ref_state=nothing, kwargs...)

Create the initial particle container by simulating `N` particles from the model's
initial distribution with zero initial log-weights.
"""
function initialise(
    rng::AbstractRNG,
    model::StateSpaceModel{T},
    filter::GuidedFilter{N};
    ref_state::Union{Nothing,AbstractVector}=nothing,
    kwargs...,
) where {N,T}
    particles = [SSMProblems.simulate(rng, model.dyn; kwargs...) for _ in 1:N]
    log_weights = zeros(T, N)
    return update_ref!(ParticleContainer(particles, log_weights), ref_state, filter)
end

"""
    predict(rng, model, filter::GuidedFilter, step, states, observation; ref_state=nothing, kwargs...)

Resample the filtered particles and propagate each one through the proposal
distribution, conditioning on the current `observation`.
"""
function predict(
    rng::AbstractRNG,
    model::StateSpaceModel,
    filter::GuidedFilter,
    step::Integer,
    states::ParticleContainer{T},
    observation;
    ref_state::Union{Nothing,AbstractVector{T}}=nothing,
    kwargs...,
) where {T}
    states.proposed, states.ancestors = resample(
        rng, filter.resampler, states.filtered, filter
    )
    states.proposed.particles = [
        SSMProblems.simulate(rng, model, filter.proposal, step, x, observation; kwargs...)
        for x in collect(states.proposed)
    ]
    return update_ref!(states, ref_state, filter, step)
end

"""
    update(model, filter::GuidedFilter, step, states, observation; kwargs...)

Reweight the proposed particles with the guided-filter incremental weight
`log f + log g - log q` (transition, observation and proposal log-densities) and return
the updated container together with the log-marginal increment.
"""
function update(
    model::StateSpaceModel{T},
    filter::GuidedFilter{N},
    step::Integer,
    states::ParticleContainer,
    observation;
    kwargs...,
) where {T,N}
    # Pair each proposed particle with its ancestor from the filtered set.
    # This is a little messy and may require a deepcopy.
    particle_collection = zip(
        collect(states.proposed), deepcopy(states.filtered.particles[states.ancestors])
    )

    log_increments = map(particle_collection) do (new_state, prev_state)
        log_f = SSMProblems.logdensity(model.dyn, step, prev_state, new_state; kwargs...)
        log_g = SSMProblems.logdensity(model.obs, step, new_state, observation; kwargs...)
        log_q = SSMProblems.logdensity(
            model, filter.proposal, step, prev_state, new_state, observation; kwargs...
        )

        (log_f + log_g - log_q)
    end

    states.filtered.log_weights = states.proposed.log_weights + log_increments
    states.filtered.particles = states.proposed.particles

    return states, logmarginal(states, filter)
end

"""
    step(rng, model, alg::GuidedFilter, iter, state, observation; kwargs...)

Perform one guided-filter step: predict (propose) followed by update (reweight).
Returns the filtered state and the log-marginal increment.
"""
function step(
    rng::AbstractRNG,
    model::AbstractStateSpaceModel,
    alg::GuidedFilter,
    iter::Integer,
    state,
    observation;
    kwargs...,
)
    proposed = predict(rng, model, alg, iter, state, observation; kwargs...)
    return update(model, alg, iter, proposed, observation; kwargs...)
end

"""
    reset_weights!(state, idxs, ::GuidedFilter)

Reset all log-weights to zero (uniform weights) after resampling; returns `state`.
"""
function reset_weights!(state::ParticleState{T,WT}, idxs, ::GuidedFilter) where {T,WT<:Real}
    state.log_weights .= zero(WT)
    return state
end

"""
    logmarginal(states, ::GuidedFilter)

Incremental log-marginal likelihood: difference between the log-normalising constants of
the filtered and proposed particle weights.
"""
function logmarginal(states::ParticleContainer, ::GuidedFilter)
    log_zf = logsumexp(states.filtered.log_weights)
    log_zp = logsumexp(states.proposed.log_weights)
    return log_zf - log_zp
end
4 changes: 2 additions & 2 deletions src/containers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ end
function (c::AncestorCallback)(model, filter, step, states, data; kwargs...)
if step == 1
# this may be incorrect, but it is functional
@inbounds c.tree.states[1:(filter.N)] = deepcopy(states.filtered.particles)
@inbounds c.tree.states[keys(states.filtered)] = deepcopy(states.filtered.particles)
end
# TODO: when using non-stack version, may be more efficient to wait until storage full
# to prune
Expand Down Expand Up @@ -304,7 +304,7 @@ end
function (c::ResamplerCallback)(model, filter, step, states, data; kwargs...)
if step != 1
prune!(c.tree, get_offspring(states.ancestors))
insert!(c.tree, collect(1:(filter.N)), states.ancestors)
insert!(c.tree, collect(keys(states.filtered)), states.ancestors)
end
return nothing
end
Loading