Skip to content

Commit

Permalink
add linear measurement model and tests for mixing
Browse files Browse the repository at this point in the history
  • Loading branch information
baggepinnen committed Dec 11, 2024
1 parent 03e419a commit f25f92c
Show file tree
Hide file tree
Showing 6 changed files with 148 additions and 35 deletions.
1 change: 1 addition & 0 deletions src/LowLevelParticleFilters.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module LowLevelParticleFilters

export KalmanFilter, SqKalmanFilter, UnscentedKalmanFilter, DAEUnscentedKalmanFilter, ExtendedKalmanFilter, ParticleFilter, AuxiliaryParticleFilter, AdvancedParticleFilter, PFstate, index, state, covariance, num_particles, effective_particles, weights, expweights, particles, particletype, smooth, sample_measurement, simulate, loglik, log_likelihood_fun, forward_trajectory, mean_trajectory, mode_trajectory, weighted_mean, weighted_cov, update!, predict!, correct!, reset!, metropolis, shouldresample, TupleProduct
export LinearMeasurementModel, EKFMeasurementModel, UKFMeasurementModel, CompositeMesurementModel
@deprecate weigthed_mean weighted_mean
@deprecate weigthed_cov weighted_cov

Expand Down
4 changes: 2 additions & 2 deletions src/ekf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,8 @@ function correct!(ukf::AbstractExtendedKalmanFilter, u, y, p, t::Real; kwargs...
correct!(ukf, measurement_model, u, y, p, t::Real; kwargs...)
end

function correct!(kf::AbstractExtendedKalmanFilter{<:Any, IPM}, measurement_model::EKFMeasurementModel, u, y, p = parameters(kf), t::Real = index(kf); R2 = get_mat(kf.measurement_model.R2, kf.x, u, p, t))where IPM
@unpack x,R = kf
function correct!(kf::AbstractKalmanFilter, measurement_model::EKFMeasurementModel{IPM}, u, y, p = parameters(kf), t::Real = index(kf); R2 = get_mat(measurement_model.R2, kf.x, u, p, t)) where IPM
(; x,R) = kf
(; measurement, Cjac) = measurement_model
C = Cjac(x, u, p, t)
if IPM
Expand Down
9 changes: 7 additions & 2 deletions src/filtering.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ The correct step for a Kalman filter returns not only the log likelihood `ll` an
If `R2` stored in `kf` is a function `R2(x, u, p, t)`, this function is evaluated at the state *before* the correction is performed.
The measurement noise covariance matrix `R2` stored in the filter object can optionally be overridden by passing the argument `R2`, in this case `R2` must be a matrix.
"""
function correct!(kf::AbstractKalmanFilter, u, y, p=parameters(kf), t::Real = index(kf)*kf.Ts; R2 = get_mat(kf.R2, kf.x, u, p, t), Ct = get_mat(kf.C, kf.x, u, p, t), Dt = get_mat(kf.D, kf.x, u, p, t))
@unpack x,R = kf
function correct!(kf::AbstractKalmanFilter, mm::LinearMeasurementModel, u, y, p=parameters(kf), t::Real = index(kf)*kf.Ts; R2 = get_mat(mm.R2, kf.x, u, p, t), Ct = get_mat(mm.C, kf.x, u, p, t), Dt = get_mat(mm.D, kf.x, u, p, t))
(;x,R) = kf
e = y .- Ct*x
if !iszero(Dt)
e -= Dt*u
Expand All @@ -103,6 +103,11 @@ function correct!(kf::AbstractKalmanFilter, u, y, p=parameters(kf), t::Real = in
(; ll, e, S, Sᵪ, K)
end

function correct!(kf::AbstractKalmanFilter, u, y, p=parameters(kf), t::Real = index(kf)*kf.Ts; kwargs...)
measurement_model = LinearMeasurementModel(kf.C, kf.D, kf.R2, length(y), nothing)
correct!(kf, measurement_model, u, y, p, t; kwargs...)
end

"""
predict!(f, u, p = parameters(f), t = index(f))
Expand Down
83 changes: 82 additions & 1 deletion src/measurement_model.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
abstract type AbstractMeasurementModel end

"""
ComponsiteMeasurementModel{M}
A composite measurement model that combines multiple measurement models.
# Fields:
- `models`: A tuple of measurement models
"""
struct ComponsiteMeasurementModel{M} <: AbstractMeasurementModel
models::M
end
Expand All @@ -18,6 +26,21 @@ end

isinplace(::UKFMeasurementModel{IPM}) where IPM = IPM

"""
UKFMeasurementModel{inplace_measurement,augmented_measurement}(measurement, R2, ny, ne, innovation, mean, cov, cross_cov, cache = nothing)
A measurement model for the Unscented Kalman Filter.
# Arguments:
- `measurement`: The measurement function `y = h(x, u, p, t)`
- `R2`: The measurement noise covariance matrix
- `ny`: The number of measurement variables
- `ne`: If `augmented_measurement` is `true`, the number of measurement noise variables
- `innovation`: The innovation function `innovation(y, yh) -> e`
- `mean(ys::AbstractVector{<:AbstractVector})`: computes the mean of the vector of vectors of output sigma points.
- `cov(ys::AbstractVector{<:AbstractVector}, y::AbstractVector)`: computes the covariance matrix of the output sigma points.
- `cross_cov(xs::AbstractVector{<:AbstractVector}, x::AbstractVector, ys::AbstractVector{<:AbstractVector}, y::AbstractVector)` where the arguments represents (state sigma points, mean state, output sigma points, mean output). The function should return the **cross-covariance** matrix between the state and output sigma points.
"""
UKFMeasurementModel{IPM,AUGM}(
measurement,
R2,
Expand Down Expand Up @@ -65,7 +88,13 @@ function add_cache(model::UKFMeasurementModel{IPM,AUGM}, cache) where {IPM,AUGM}
)
end

"""
UKFMeasurementModel{T,IPM,AUGM}(measurement, R2; nx, ny, ne = nothing, innovation = -, mean = safe_mean, cov = safe_cov, cross_cov = cross_cov, static = nothing)
- `T` is the element type used for arrays
- `IPM` is a boolean indicating if the measurement function is inplace
- `AUGM` is a boolean indicating if the measurement model is augmented
"""
function UKFMeasurementModel{T,IPM,AUGM}(
measurement,
R2;
Expand Down Expand Up @@ -133,6 +162,17 @@ struct SigmaPointCache{X0, X1}
x1::X1
end

"""
SigmaPointCache(nx, nw, ny, L, static)
# Arguments:
- `nx`: Number of state variables
- `nw`: Number of process noise variables for augmented dynamics. If not using augmented dynamics, set to 0.
- `ny`: Number of transformed sigma points
- `L`: Number of sigma points
- `static`: If `true`, the cache will use static arrays for the sigma points. This can be faster for small systems.
"""
function SigmaPointCache{T}(nx, nw, ny, L, static) where T
if static
x0 = [@SVector zeros(T, nx + nw) for _ = 1:2L+1]
Expand All @@ -159,6 +199,18 @@ end

isinplace(::EKFMeasurementModel{IPM}) where IPM = IPM

"""
EKFMeasurementModel{IPM}(measurement, R2, ny, Cjac, cache = nothing)
A measurement model for the Extended Kalman Filter.
# Arguments:
- `IPM`: A boolean indicating if the measurement function is inplace
- `measurement`: The measurement function `y = h(x, u, p, t)`
- `R2`: The measurement noise covariance matrix
- `ny`: The number of measurement variables
- `Cjac`: The Jacobian of the measurement function `Cjac(x, u, p, t)`. If none is provided, ForwardDiff will be used.
"""
EKFMeasurementModel{IPM}(
measurement,
R2,
Expand Down Expand Up @@ -190,7 +242,12 @@ function add_cache(model::EKFMeasurementModel{IPM}, cache) where {IPM}
)
end

"""
EKFMeasurementModel{T,IPM}(measurement::M, R2; nx, ny, Cjac = nothing)
- `T` is the element type used for arrays
- `IPM` is a boolean indicating if the measurement function is inplace
"""
function EKFMeasurementModel{T,IPM}(
measurement::M,
R2;
Expand Down Expand Up @@ -224,4 +281,28 @@ function EKFMeasurementModel{T,IPM}(
Cjac,
nothing,
)
end
end


## Linear measurement model ====================================================

"""
LinearMeasurementModel{CT, DT, RT, CAT}
A linear measurement model ``y = C*x + D*u + e``.
# Fields:
- `C`
- `D`
- `R2`: The measurement noise covariance matrix
- `ny`: The number of measurement variables
"""
struct LinearMeasurementModel{CT,DT,RT,CAT} <: AbstractMeasurementModel
C::CT
D::DT
R2::RT
ny::Int
cache::CAT
end

LinearMeasurementModel(C, D, R2; ny = size(R2, 1), cache = nothing, nx=nothing) = LinearMeasurementModel(C, D, R2, ny, cache)
49 changes: 26 additions & 23 deletions src/ukf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -336,28 +336,28 @@ end


function correct!(
ukf::UnscentedKalmanFilter{IPD,IPM,AUGD,AUGM},
kf::AbstractKalmanFilter,
measurement_model::UKFMeasurementModel,
u,
y,
p = parameters(ukf),
t::Real = index(ukf) * ukf.Ts;
R2 = get_mat(measurement_model.R2, ukf.x, u, p, t),
p = parameters(kf),
t::Real = index(kf) * kf.Ts;
R2 = get_mat(measurement_model.R2, kf.x, u, p, t),
mean = measurement_model.mean,
measurement_cov = measurement_model.cross_cov,
innovation = measurement_model.innovation,
measurement = measurement_model.measurement,
) where {IPD,IPM,AUGD,AUGM}
)

sigma_point_cache = measurement_model.cache
xsm = sigma_point_cache.x0
ys = sigma_point_cache.x1
(; x, R) = ukf
(; x, R) = kf

T = promote_type(eltype(x), eltype(R), eltype(R2))
ns = length(xsm)
sigmapoints_c!(ukf, sigma_point_cache, R2) # TODO: should this take other arguments?
propagate_sigmapoints_c!(ukf, u, p, t, R2, measurement_model)
sigmapoints_c!(kf, measurement_model, R2) # TODO: should this take other arguments?
propagate_sigmapoints_c!(kf, u, p, t, R2, measurement_model)
ym = mean(ys)
C = measurement_cov(xsm, x, ys, ym)
e = innovation(y, ym)
Expand All @@ -366,28 +366,31 @@ function correct!(
issuccess(Sᵪ) ||
error("Cholesky factorization of innovation covariance failed, got S = ", S)
K = (C ./ (ns - 1)) / Sᵪ # ns normalization to make it a covariance matrix
ukf.x += K * e
kf.x += K * e
# mul!(x, K, e, 1, 1) # K and e will be SVectors if ukf correctly initialized
RmKSKT!(ukf, K, S)
RmKSKT!(kf, K, S)
ll = extended_logpdf(SimpleMvNormal(PDMat(S, Sᵪ)), e) #- 1/2*logdet(S) # logdet is included in logpdf
(; ll, e, S, Sᵪ, K)
end

# AUGM = false
function sigmapoints_c!(
ukf::UnscentedKalmanFilter{<:Any,<:Any,<:Any,false},
sigma_point_cache,
kf,
measurement_model::UKFMeasurementModel{<:Any,false},
R2,
)
sigma_point_cache = measurement_model.cache
xsm = sigma_point_cache.x0
sigmapoints!(xsm, eltype(xsm)(ukf.x), ukf.R)
sigmapoints!(xsm, eltype(xsm)(kf.x), kf.R)
end

function sigmapoints_c!(
ukf::UnscentedKalmanFilter{<:Any,<:Any,<:Any,true},
sigma_point_cache,
kf,
measurement_model::UKFMeasurementModel{<:Any,true},
R2,
)
(; x, R) = ukf
(; x, R) = kf
sigma_point_cache = measurement_model.cache
xsm = sigma_point_cache.x0
nx = length(x)
nv = size(R2, 1)
Expand All @@ -398,12 +401,12 @@ end

# IPM = true
function propagate_sigmapoints_c!(
ukf::UnscentedKalmanFilter{<:Any,true,<:Any},
kf,
u,
p,
t,
R2,
measurement_model,
measurement_model::UKFMeasurementModel{true},
)
sigma_point_cache = measurement_model.cache
xsm = sigma_point_cache.x0
Expand All @@ -415,17 +418,17 @@ end

# AUGM = true
function propagate_sigmapoints_c!(
ukf::UnscentedKalmanFilter{<:Any,false,<:Any,true},
kf,
u,
p,
t,
R2,
measurement_model,
measurement_model::UKFMeasurementModel{false,true},
)
sigma_point_cache = measurement_model.cache
xsm = sigma_point_cache.x0
ys = sigma_point_cache.x1
(; x, R) = ukf
(; x, R) = kf
nx = length(x)
nv = size(R2, 1)
xinds = 1:nx
Expand All @@ -437,12 +440,12 @@ end

# AUGM = false
function propagate_sigmapoints_c!(
ukf::UnscentedKalmanFilter{<:Any,false,<:Any,false},
kf,
u,
p,
t,
R2,
measurement_model,
measurement_model::UKFMeasurementModel{false,false},
)
sigma_point_cache = measurement_model.cache
xsm = sigma_point_cache.x0
Expand Down
37 changes: 30 additions & 7 deletions test/test_large.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,14 @@ const __C = randn(ny,nx)
dynamics_large(x,u,p,t) = __A*x .+ __B*u
measurement_large(x,u,p,t) = __C*x

R1 = I(nx)
R2 = I(ny)

T = 200 # Number of time steps
kf = KalmanFilter(__A, __B, __C, 0, I(nx), I(ny))
skf = SqKalmanFilter(__A, __B, __C, 0, I(nx), I(ny))
ukf = UnscentedKalmanFilter(dynamics_large, measurement_large, I(nx), I(ny); ny, nu)
ekf = ExtendedKalmanFilter(dynamics_large, measurement_large, I(nx), I(ny); nu)
kf = KalmanFilter(__A, __B, __C, 0, R1, R2)
skf = SqKalmanFilter(__A, __B, __C, 0, R1, R2)
ukf = UnscentedKalmanFilter(dynamics_large, measurement_large, R1, R2; ny, nu)
ekf = ExtendedKalmanFilter(dynamics_large, measurement_large, R1, R2; nu)

U = [randn(nu) for _ in 1:T]
x,u,y = LowLevelParticleFilters.simulate(kf, U) # Simuate trajectory using the model in the filter
Expand Down Expand Up @@ -76,8 +79,8 @@ function measurement_large_ip(y,x,u,p,t)
nothing
end

ukf = UnscentedKalmanFilter(dynamics_large_ip, measurement_large_ip, I(nx), I(ny); ny, nu)
ekf = ExtendedKalmanFilter(dynamics_large_ip, measurement_large_ip, I(nx), I(ny); nu)
ukf = UnscentedKalmanFilter(dynamics_large_ip, measurement_large_ip, R1, R2; ny, nu)
ekf = ExtendedKalmanFilter(dynamics_large_ip, measurement_large_ip, R1, R2; nu)

sol_ukf = forward_trajectory(ukf, u, y)
a = @allocations forward_trajectory(ukf, u, y)
Expand All @@ -100,4 +103,24 @@ using Plots
plot(sol_kf, plothy = true, plote = true)
plot(sol_ukf, plothy = true, plote = true, plotR=true)
plot(sol_ekf, plothy = true, plote = true, plotRt=true)
plot(sol_sqkf, plothy = true, plote = true)
plot(sol_sqkf, plothy = true, plote = true)

## Test mixing of measurement models ===========================================

mm_ukf = UKFMeasurementModel{Float64, true, false}(measurement_large_ip, R2; nx, ny)
mm_ekf = EKFMeasurementModel{Float64, true}(measurement_large_ip, R2; nx, ny)
mm_kf = LinearMeasurementModel(__C, 0, R2; nx, ny)


mms = [mm_ukf, mm_ekf, mm_kf]

for mm in mms
@show nameof(typeof(mm))

correct!(kf, mm, u[1], y[1])
correct!(ekf, mm, u[1], y[1])
correct!(ukf, mm, u[1], y[1])

@test kf.x ekf.x ukf.x
@test kf.R ekf.R ukf.R
end

0 comments on commit f25f92c

Please sign in to comment.