Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented wrapper for KalmanFilters.jl SRKF #5

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ version = "0.1.0"
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5"
KalmanFilters = "272a6111-cf0e-4c1b-a056-8d658cb314ee"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Expand Down
112 changes: 105 additions & 7 deletions src/algorithms/kalman.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
export KalmanFilter, filter
using KalmanFilters: KalmanFilters
import LinearAlgebra: Cholesky, cholesky, diag, dot

export KalmanFilter, filter, SRKF

"""
    KalmanFilter()

Filtering algorithm type with no fields.

Used purely for dispatch: it selects the `predict`, `update` and `step` methods that
carry the filtering state as a `(μ, Σ)` named tuple.
"""
struct KalmanFilter <: FilteringAlgorithm
end

Expand All @@ -13,9 +16,9 @@ function predict(
model::LinearGaussianStateSpaceModel{T},
filter::KalmanFilter,
step::Integer,
state::@NamedTuple{μ::Vector{T}, Σ::Matrix{T}},
state::@NamedTuple{μ::V, Σ::M},
extra,
) where {T}
) where {T,V<:AbstractVector{T},M<:AbstractMatrix{T}}
μ, Σ = state.μ, state.Σ
A, b, Q = calc_params(model.dyn, step, extra)
μ̂ = A * μ + b
Expand All @@ -27,10 +30,10 @@ function update(
model::LinearGaussianStateSpaceModel{T},
filter::KalmanFilter,
step::Integer,
state::@NamedTuple{μ::Vector{T}, Σ::Matrix{T}},
state::@NamedTuple{μ::V, Σ::M},
obs::Vector{T},
extra,
) where {T}
) where {T,V<:AbstractVector{T},M<:AbstractMatrix{T}}
μ, Σ = state.μ, state.Σ
H, c, R = calc_params(model.obs, step, extra)

Expand All @@ -43,6 +46,7 @@ function update(
Σ̂ = Σ - K * H * Σ

# Compute log-likelihood
S = (S + S') / 2 # HACK: ensure S is symmetric (due to numerical errors)
ll = logpdf(MvNormal(m, S), obs)

return (μ=μ̂, Σ=Σ̂), ll
Expand All @@ -52,10 +56,10 @@ function step(
model::LinearGaussianStateSpaceModel{T},
filter::KalmanFilter,
step::Integer,
state::@NamedTuple{μ::Vector{T}, Σ::Matrix{T}},
state::@NamedTuple{μ::V, Σ::M},
obs::Vector{T},
extra,
) where {T}
) where {T,V<:AbstractVector{T},M<:AbstractMatrix{T}}
state = predict(model, filter, step, state, extra)
state, ll = update(model, filter, step, state, obs, extra)
return state, ll
Expand All @@ -78,3 +82,97 @@ function filter(
end
return states, ll
end

"""
    SRKF()

Square-root Kalman filter algorithm.

A field-less type used for dispatch. The associated methods wrap KalmanFilters.jl and
carry the state covariance as a Cholesky factorisation (see `FactorisedGaussian`).
"""
struct SRKF <: FilteringAlgorithm
end

"""
    FactorisedGaussian(μ, Σ)

Gaussian state whose covariance is stored in factorised form.

# Fields
- `μ`: mean vector with element type `T`.
- `Σ`: `Cholesky` factorisation (element type `T`) representing the covariance.
"""
struct FactorisedGaussian{T<:Real,VT<:AbstractVector{T},CT<:Cholesky{T}}
    μ::VT
    Σ::CT
end

"""
    initialise(model, ::SRKF, extra)

Build the initial square-root filter state for `model`.

The initial mean and covariance are obtained from `calc_initial(model.dyn, extra)`; the
covariance is Cholesky-factorised before being wrapped in a `FactorisedGaussian`.
"""
function initialise(model::LinearGaussianStateSpaceModel{T}, ::SRKF, extra) where {T}
    prior_mean, prior_cov = calc_initial(model.dyn, extra)
    prior_factor = cholesky(prior_cov)
    return FactorisedGaussian(prior_mean, prior_factor)
end

"""
    predict(model, ::SRKF, step, state, extra)

Propagate a `FactorisedGaussian` state through the latent dynamics at `step`.

The dynamics parameters `A, b, Q` come from `calc_params(model.dyn, step, extra)`. The
time update is delegated to `KalmanFilters.time_update`, which receives the Cholesky
factorisation of `Q`. Returns the predicted state as a new `FactorisedGaussian`.

Raises an error if `b` has any non-zero entry, since the offset is not passed on to the
wrapped time update.
"""
function predict(
    model::LinearGaussianStateSpaceModel{T},
    ::SRKF,
    step::Integer,
    state::FactorisedGaussian{T},
    extra,
) where {T}
    A, b, Q = calc_params(model.dyn, step, extra)
    all(b .== 0) || error("Non-zero b not supported for SRKF")

    prediction = KalmanFilters.time_update(state.μ, state.Σ, A, cholesky(Q))
    predicted_mean = KalmanFilters.get_state(prediction)
    predicted_factor = KalmanFilters.get_sqrt_covariance(prediction)
    return FactorisedGaussian(predicted_mean, predicted_factor)
end

"""
    update(model, ::SRKF, step, state, obs, extra)

Condition a `FactorisedGaussian` state on the observation `obs` at `step`.

The observation parameters `H, c, R` come from `calc_params(model.obs, step, extra)`. The
measurement update is delegated to `KalmanFilters.measurement_update`, which receives the
Cholesky factorisation of `R`. Returns `(state, ll)`: the updated `FactorisedGaussian` and
the log-likelihood computed from the innovation and its factorised covariance.

Raises an error if `c` has any non-zero entry, since the offset is not passed on to the
wrapped measurement update.
"""
function update(
    model::LinearGaussianStateSpaceModel{T},
    ::SRKF,
    step::Integer,
    state::FactorisedGaussian{T},
    obs::Vector{T},
    extra,
) where {T}
    H, c, R = calc_params(model.obs, step, extra)
    all(c .== 0) || error("Non-zero c not supported for SRKF")

    correction = KalmanFilters.measurement_update(state.μ, state.Σ, obs, H, cholesky(R))

    # The factor L is produced via a QR decomposition, so its diagonal may contain
    # non-positive entries; `compute_ll` is written to tolerate that.
    ll = compute_ll(correction.innovation, correction.innovation_covariance)

    updated_mean = KalmanFilters.get_state(correction)
    updated_factor = KalmanFilters.get_sqrt_covariance(correction)
    return FactorisedGaussian(updated_mean, updated_factor), ll
end

"""
    compute_ll(ỹ, Σ::Cholesky)

Gaussian log-density of the innovation `ỹ` under a zero-mean normal whose covariance is
given by the factorisation `Σ`.

Computed by hand from the lower factor so that it stays valid when the diagonal of `Σ.L`
contains non-positive entries: the log-determinant uses absolute values of the diagonal.
"""
function compute_ll(ỹ, Σ::Cholesky)
    lower = Σ.L
    whitened = lower \ ỹ
    quad_form = dot(whitened, whitened)
    # abs() guards against negative diagonal entries of the factor
    log_det = 2 * sum(d -> log(abs(d)), diag(lower))
    return -0.5 * (quad_form + log_det + length(ỹ) * log(2π))
end

"""
    step(model, filter::SRKF, step, state, obs, extra)

Run one full square-root filtering step: `predict` followed by `update`.

Returns `(state, ll)` exactly as produced by `update`.
"""
function step(
    model::LinearGaussianStateSpaceModel{T},
    filter::SRKF,
    step::Integer,
    state::FactorisedGaussian{T},
    obs::Vector{T},
    extra,
) where {T}
    predicted = predict(model, filter, step, state, extra)
    return update(model, filter, step, predicted, obs, extra)
end

"""
    filter(model, filter::SRKF, data, extra0, extras)

Run the square-root Kalman filter over every observation in `data`.

The state is initialised with `extra0`; observation `i` is processed with `extras[i]`.
Returns `(states, ll)`: the filtered `FactorisedGaussian` after each observation and the
sum of the per-step log-likelihoods.
"""
function filter(
    model::LinearGaussianStateSpaceModel{T},
    filter::SRKF,
    data::Vector{Vector{T}},
    extra0,
    extras,
) where {T}
    current = initialise(model, filter, extra0)
    history = Vector{FactorisedGaussian{T}}(undef, length(data))
    total_ll = 0.0
    for t in eachindex(data)
        current, increment = step(model, filter, t, current, data[t], extras[t])
        history[t] = current
        total_ll += increment
    end
    return history, total_ll
end
26 changes: 15 additions & 11 deletions src/models/linear_gaussian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ function calc_params(obs::LinearGaussianObservationProcess, step::Integer, extra
end

const LinearGaussianStateSpaceModel{T} = SSMProblems.StateSpaceModel{
D,O
T,D,O
} where {T,D<:LinearGaussianLatentDynamics{T},O<:LinearGaussianObservationProcess{T}}

# TODO: this is hacky and should ideally be removed
Expand Down Expand Up @@ -73,23 +73,27 @@ end
#### HOMOGENEOUS LINEAR GAUSSIAN MODEL ####
###########################################

struct HomogeneousLinearGaussianLatentDynamics{T} <: LinearGaussianLatentDynamics{T}
μ0::Vector{T}
Σ0::Matrix{T}
A::Matrix{T}
b::Vector{T}
Q::Matrix{T}
struct HomogeneousLinearGaussianLatentDynamics{
T,V<:AbstractVector{T},M_ge<:AbstractMatrix{T},M_cov<:AbstractMatrix{T}
} <: LinearGaussianLatentDynamics{T}
μ0::V
Σ0::M_cov
A::M_ge
b::V
Q::M_cov
end
# Accessors for the homogeneous latent dynamics: each returns the corresponding stored
# field unchanged; the step index and `extra` arguments are ignored.
function calc_μ0(dyn::HomogeneousLinearGaussianLatentDynamics, extra)
    return dyn.μ0
end

function calc_Σ0(dyn::HomogeneousLinearGaussianLatentDynamics, extra)
    return dyn.Σ0
end

function calc_A(dyn::HomogeneousLinearGaussianLatentDynamics, ::Integer, extra)
    return dyn.A
end

function calc_b(dyn::HomogeneousLinearGaussianLatentDynamics, ::Integer, extra)
    return dyn.b
end

function calc_Q(dyn::HomogeneousLinearGaussianLatentDynamics, ::Integer, extra)
    return dyn.Q
end

struct HomogeneousLinearGaussianObservationProcess{T} <: LinearGaussianObservationProcess{T}
H::Matrix{T}
c::Vector{T}
R::Matrix{T}
struct HomogeneousLinearGaussianObservationProcess{
T,V<:AbstractVector{T},M_ge<:AbstractMatrix{T},M_cov<:AbstractMatrix{T}
} <: LinearGaussianObservationProcess{T}
H::M_ge
c::V
R::M_cov
end
# Accessors for the homogeneous observation process: each returns the corresponding
# stored field unchanged; the step index and `extra` arguments are ignored.
function calc_H(obs::HomogeneousLinearGaussianObservationProcess, ::Integer, extra)
    return obs.H
end

function calc_c(obs::HomogeneousLinearGaussianObservationProcess, ::Integer, extra)
    return obs.c
end
Expand Down
41 changes: 41 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,47 @@ using TestItemRunner
@test only(states).Σ ≈ Σ_X1
end

@testitem "Square root Kalman filter test" begin
    # NOTE(review): the neighbouring "Forward algorithm test" item loads `AnalyticFilters`;
    # confirm which package name is current.
    using AnalyticalFilters
    using LinearAlgebra
    using PDMats
    using StableRNGs

    rng = StableRNG(1234)

    # Random draws are made in a fixed order (μ0, Σ0, A, Q, H, R, observations) so the
    # generated model is reproducible. Covariances are formed as M * M' and wrapped in
    # PDMat.
    μ0 = rand(rng, 2)
    Σ0_root = rand(rng, 2, 2)
    Σ0 = PDMat(Σ0_root * Σ0_root')

    A = rand(rng, 2, 2)
    b = zeros(2)  # SRKF only supports a zero dynamics offset
    Q_root = rand(rng, 2, 2)
    Q = PDMat(Q_root * Q_root')

    H = rand(rng, 2, 2)
    c = zeros(2)  # SRKF only supports a zero observation offset
    R_root = rand(rng, 2, 2)
    R = PDMat(R_root * R_root')

    model = create_homogeneous_linear_gaussian_model(μ0, Σ0, A, b, Q, H, c, R)

    n_steps = 5
    observations = [rand(rng, 2) for _ in 1:n_steps]
    extras = fill(nothing, n_steps)

    # Reference: standard Kalman filter
    kf_states, kf_ll = AnalyticalFilters.filter(
        model, KalmanFilter(), observations, nothing, extras
    )

    # Candidate: square-root Kalman filter
    srkf_states, srkf_ll = AnalyticalFilters.filter(
        model, SRKF(), observations, nothing, extras
    )

    kf_final = last(kf_states)
    srkf_final = last(srkf_states)

    @test kf_final.μ ≈ srkf_final.μ
    @test kf_final.Σ ≈ srkf_final.Σ.L * srkf_final.Σ.L'
    @test kf_ll ≈ srkf_ll
end

@testitem "Forward algorithm test" begin
using AnalyticFilters
using Distributions
Expand Down