-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add (naive) FFBS algo #20
Draft
FredericWantiez
wants to merge
13
commits into
fred/auxiliary-particle-filter
Choose a base branch
from
fred/ffbs
base: fred/auxiliary-particle-filter
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+418
−18
Draft
Changes from all commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
440cdc8
Add auxiliary particle filter
FredericWantiez 2c33c80
Mean transition
FredericWantiez 5198e3c
Merge conflict
FredericWantiez 9254898
test
FredericWantiez fa84e14
Add FFBS draft, change filter interface
FredericWantiez 4b948db
Fix backward weights:
FredericWantiez 556babd
Weigghts:
FredericWantiez 20faf22
Split trajectory
FredericWantiez 9f03987
format
FredericWantiez 899c507
Clean up
FredericWantiez 51fd883
Ambiguous
FredericWantiez d4e2a40
fixed ancestry callbacks
charlesknipp 756d1f0
add guided particle filer
charlesknipp File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
export FFBS | ||
|
||
abstract type AbstractSmoother <: AbstractSampler end | ||
|
||
struct FFBS{T<:AbstractParticleFilter} | ||
filter::T | ||
end | ||
|
||
""" | ||
smooth(rng::AbstractRNG, alg::AbstractSmooterh, model::AbstractStateSpaceModel, obs::AbstractVector, M::Integer; callback, kwargs...) | ||
""" | ||
function smooth end | ||
|
||
struct WeightedParticleRecorderCallback{T,WT} | ||
particles::Array{T} | ||
log_weights::Array{WT} | ||
end | ||
|
||
function (callback::WeightedParticleRecorderCallback)( | ||
model, filter, step, states, data; kwargs... | ||
) | ||
filtered_states = states.filtered | ||
callback.particles[step, :] = filtered_states.particles | ||
callback.log_weights[step, :] = filtered_states.log_weights | ||
return nothing | ||
end | ||
|
||
function gen_trajectory( | ||
rng::Random.AbstractRNG, | ||
model::StateSpaceModel, | ||
particles::AbstractMatrix{T}, # Need better container | ||
log_weights::AbstractMatrix{WT}, | ||
forward_state, | ||
n_timestep::Int; | ||
kwargs..., | ||
) where {T,WT} | ||
trajectory = Vector{T}(undef, n_timestep) | ||
trajectory[end] = forward_state | ||
for step in (n_timestep - 1):-1:1 | ||
backward_weights = backward( | ||
model, | ||
step, | ||
trajectory[step + 1], | ||
particles[step, :], | ||
log_weights[step, :]; | ||
kwargs..., | ||
) | ||
ancestor = rand(rng, Categorical(softmax(backward_weights))) | ||
trajectory[step] = particles[step, ancestor] | ||
end | ||
return trajectory | ||
end | ||
|
||
function backward( | ||
model::StateSpaceModel, step::Integer, state, particles::T, log_weights::WT; kwargs... | ||
) where {T,WT} | ||
transitions = map(particles) do prev_state | ||
SSMProblems.logdensity(model.dyn, step, prev_state, state; kwargs...) | ||
end | ||
return log_weights + transitions | ||
end | ||
|
||
function sample( | ||
rng::Random.AbstractRNG, | ||
model::StateSpaceModel{T,LDT}, | ||
alg::FFBS{<:BootstrapFilter{N}}, | ||
obs::AbstractVector, | ||
M::Integer; | ||
callback=nothing, | ||
kwargs..., | ||
) where {T,LDT,N} | ||
n_timestep = length(obs) | ||
recorder = WeightedParticleRecorderCallback( | ||
Array{eltype(model.dyn)}(undef, n_timestep, N), Array{T}(undef, n_timestep, N) | ||
) | ||
|
||
particles, _ = filter(rng, model, alg.filter, obs; callback=recorder, kwargs...) | ||
|
||
# Backward sampling - exact | ||
idx_ref = rand(rng, Categorical(weights(particles.filtered)), M) | ||
trajectories = Array{eltype(model.dyn)}(undef, n_timestep, M) | ||
|
||
trajectories[end, :] = particles.filtered[idx_ref] | ||
for j in 1:M | ||
trajectories[:, j] = gen_trajectory( | ||
rng, | ||
model, | ||
recorder.particles, | ||
recorder.log_weights, | ||
trajectories[end, j], | ||
n_timestep, | ||
) | ||
end | ||
return trajectories | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,164 @@ | ||
export GuidedFilter, GPF, AbstractProposal | ||
|
||
## PROPOSALS ############################################################################### | ||
""" | ||
AbstractProposal | ||
""" | ||
abstract type AbstractProposal end | ||
|
||
function SSMProblems.distribution( | ||
model::AbstractStateSpaceModel, | ||
prop::AbstractProposal, | ||
step::Integer, | ||
state, | ||
observation; | ||
kwargs..., | ||
) | ||
return throw( | ||
MethodError( | ||
SSMProblems.distribution, (model, prop, step, state, observation, kwargs...) | ||
), | ||
) | ||
end | ||
|
||
function SSMProblems.simulate( | ||
rng::AbstractRNG, | ||
model::AbstractStateSpaceModel, | ||
prop::AbstractProposal, | ||
step::Integer, | ||
state, | ||
observation; | ||
kwargs..., | ||
) | ||
return rand( | ||
rng, SSMProblems.distribution(model, prop, step, state, observation; kwargs...) | ||
) | ||
end | ||
|
||
function SSMProblems.logdensity( | ||
model::AbstractStateSpaceModel, | ||
prop::AbstractProposal, | ||
step::Integer, | ||
prev_state, | ||
new_state, | ||
observation; | ||
kwargs..., | ||
) | ||
return logpdf( | ||
SSMProblems.distribution(model, prop, step, prev_state, observation; kwargs...), | ||
new_state, | ||
) | ||
end | ||
|
||
## GUIDED FILTERING ######################################################################## | ||
|
||
struct GuidedFilter{N,RS<:AbstractResampler,P<:AbstractProposal} <: | ||
AbstractParticleFilter{N} | ||
resampler::RS | ||
proposal::P | ||
end | ||
|
||
function GuidedFilter( | ||
N::Integer, proposal::P; threshold::Real=1.0, resampler::AbstractResampler=Systematic() | ||
) where {P<:AbstractProposal} | ||
conditional_resampler = ESSResampler(threshold, resampler) | ||
return GuidedFilter{N,typeof(conditional_resampler),P}(conditional_resampler, proposal) | ||
end | ||
|
||
"""Shorthand for `GuidedFilter`""" | ||
const GPF = GuidedFilter | ||
|
||
function initialise( | ||
rng::AbstractRNG, | ||
model::StateSpaceModel{T}, | ||
filter::GuidedFilter{N}; | ||
ref_state::Union{Nothing,AbstractVector}=nothing, | ||
kwargs..., | ||
) where {N,T} | ||
initial_states = map(x -> SSMProblems.simulate(rng, model.dyn; kwargs...), 1:N) | ||
initial_weights = zeros(T, N) | ||
|
||
return update_ref!( | ||
ParticleContainer(initial_states, initial_weights), ref_state, filter | ||
) | ||
end | ||
|
||
function predict( | ||
rng::AbstractRNG, | ||
model::StateSpaceModel, | ||
filter::GuidedFilter, | ||
step::Integer, | ||
states::ParticleContainer{T}, | ||
observation; | ||
ref_state::Union{Nothing,AbstractVector{T}}=nothing, | ||
kwargs..., | ||
) where {T} | ||
states.proposed, states.ancestors = resample( | ||
rng, filter.resampler, states.filtered, filter | ||
) | ||
states.proposed.particles = map( | ||
x -> SSMProblems.simulate( | ||
rng, model, filter.proposal, step, x, observation; kwargs... | ||
), | ||
collect(states.proposed), | ||
) | ||
|
||
return update_ref!(states, ref_state, filter, step) | ||
end | ||
|
||
function update( | ||
model::StateSpaceModel{T}, | ||
filter::GuidedFilter{N}, | ||
step::Integer, | ||
states::ParticleContainer, | ||
observation; | ||
kwargs..., | ||
) where {T,N} | ||
# this is a little messy and may require a deepcopy | ||
particle_collection = zip( | ||
collect(states.proposed), deepcopy(states.filtered.particles[states.ancestors]) | ||
) | ||
|
||
log_increments = map(particle_collection) do (new_state, prev_state) | ||
log_f = SSMProblems.logdensity(model.dyn, step, prev_state, new_state; kwargs...) | ||
log_g = SSMProblems.logdensity(model.obs, step, new_state, observation; kwargs...) | ||
log_q = SSMProblems.logdensity( | ||
model, filter.proposal, step, prev_state, new_state, observation; kwargs... | ||
) | ||
|
||
# println(log_f) | ||
|
||
(log_f + log_g - log_q) | ||
end | ||
|
||
# println(logsumexp(log_increments)) | ||
|
||
states.filtered.log_weights = states.proposed.log_weights + log_increments | ||
states.filtered.particles = states.proposed.particles | ||
|
||
return states, logmarginal(states, filter) | ||
end | ||
|
||
function step( | ||
rng::AbstractRNG, | ||
model::AbstractStateSpaceModel, | ||
alg::GuidedFilter, | ||
iter::Integer, | ||
state, | ||
observation; | ||
kwargs..., | ||
) | ||
proposed_state = predict(rng, model, alg, iter, state, observation; kwargs...) | ||
filtered_state, ll = update(model, alg, iter, proposed_state, observation; kwargs...) | ||
|
||
return filtered_state, ll | ||
end | ||
|
||
function reset_weights!(state::ParticleState{T,WT}, idxs, ::GuidedFilter) where {T,WT<:Real} | ||
fill!(state.log_weights, zero(WT)) | ||
return state | ||
end | ||
|
||
function logmarginal(states::ParticleContainer, ::GuidedFilter) | ||
return logsumexp(states.filtered.log_weights) - logsumexp(states.proposed.log_weights) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we should change the
SSMProblems
interface. The proposal should be part of thefilter
interface, maybe something along the lines of:And we should probably update the
filter/predict
functions:There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I 100% agree with the BF integration, I was intentionally working my way up to that, but didn't want to drastically change the interface upon the first commit.
And you're totally right about the SSMProblems integration. But it was convenient to recycle the structures.