Skip to content

Commit

Permalink
add tests for combinations of measurement models
Browse files Browse the repository at this point in the history
  • Loading branch information
baggepinnen committed Dec 12, 2024
1 parent 6838ddf commit 36da609
Show file tree
Hide file tree
Showing 8 changed files with 176 additions and 10 deletions.
3 changes: 2 additions & 1 deletion src/LowLevelParticleFilters.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module LowLevelParticleFilters

export KalmanFilter, SqKalmanFilter, UnscentedKalmanFilter, DAEUnscentedKalmanFilter, ExtendedKalmanFilter, ParticleFilter, AuxiliaryParticleFilter, AdvancedParticleFilter, PFstate, index, state, covariance, num_particles, effective_particles, weights, expweights, particles, particletype, smooth, sample_measurement, simulate, loglik, log_likelihood_fun, forward_trajectory, mean_trajectory, mode_trajectory, weighted_mean, weighted_cov, update!, predict!, correct!, reset!, metropolis, shouldresample, TupleProduct
export LinearMeasurementModel, EKFMeasurementModel, UKFMeasurementModel, CompositeMesurementModel
export LinearMeasurementModel, EKFMeasurementModel, UKFMeasurementModel, CompositeMeasurementModel
@deprecate weigthed_mean weighted_mean
@deprecate weigthed_cov weighted_cov

Expand All @@ -26,6 +26,7 @@ abstract type ResamplingStrategy end
struct ResampleSystematic <: ResamplingStrategy end

abstract type AbstractFilter end
abstract type AbstractKalmanFilter <: AbstractFilter end

include("PFtypes.jl")
include("solutions.jl")
Expand Down
2 changes: 2 additions & 0 deletions src/ekf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ function Base.getproperty(ekf::EKF, s::Symbol) where EKF <: AbstractExtendedKalm
mm = getfield(ekf, :measurement_model)
if s fieldnames(typeof(mm))
return getfield(mm, s)
elseif s === :measurement
return measurement(mm)
end
kf = getfield(ekf, :kf)
if s fieldnames(typeof(kf))
Expand Down
2 changes: 0 additions & 2 deletions src/kalman.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
abstract type AbstractKalmanFilter <: AbstractFilter end

function convert_cov_type(R1, R)
if !(eltype(R) <: AbstractFloat)
R = float.(R)
Expand Down
91 changes: 86 additions & 5 deletions src/measurement_model.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,74 @@
abstract type AbstractMeasurementModel end

measurement(model::AbstractMeasurementModel) = model.measurement

struct CompositeMeasurementModel{M} <: AbstractMeasurementModel
models::M
ny::Int
R2
end

"""
ComponsiteMeasurementModel{M}
CompositeMeasurementModel(model1, model2, ...)
A composite measurement model that combines multiple measurement models. This model acts as all component models concatenated. The tuple returned from [`correct!`](@ref) will be
- `ll`: The sum of the log-likelihood of all component models
- `e`: The concatenated innovation vector
- `S`: A vector of the innovation covariance matrices of the component models
- `Sᵪ`: A vector of the Cholesky factorizations of the innovation covariance matrices of the component models
- `K`: A vector of the Kalman gains of the component models
A composite measurement model that combines multiple measurement models.
If all sensors operate on at the same rate, and all measurement models are of the same type, it's more efficient to use a single measurement model with a vector-valued measurement function.
# Fields:
- `models`: A tuple of measurement models
"""
struct ComponsiteMeasurementModel{M} <: AbstractMeasurementModel
models::M
function CompositeMeasurementModel(m1, rest...)
models = (m1, rest...)
ny = sum(m.ny for m in models)
R2 = cat([m.R2 for m in models]..., dims=(1,2))
CompositeMeasurementModel(models, ny, R2)
end

isinplace(model::CompositeMeasurementModel) = isinplace(model.models[1])

function measurement(model::CompositeMeasurementModel)
function (x,u,p,t)
y = zeros(model.ny)
i = 1
for m in model.models
y[i:i+m.ny-1] .= measurement(m)(x,u,p,t)
i += m.ny
end
y
end
end

function correct!(
kf::AbstractKalmanFilter,
measurement_model::CompositeMeasurementModel,
u,
y,
p = parameters(kf),
t::Real = index(kf) * kf.Ts;
)
ll = 0.0
e = zeros(measurement_model.ny)
S = []
Sᵪ = []
K = []
last_ind = 0
for i = 1:length(measurement_model.models)
lli, ei, Si, Sᵪi, Ki = correct!(kf, measurement_model.models[i], u, y, p, t)
ll += lli
inds = (1:measurement_model.models[i].ny) .+ last_ind
e[inds] .= ei
last_ind = inds[end]
push!(S, Si)
push!(Sᵪ, Sᵪi)
push!(K, Ki)
end
ll, e, S, Sᵪ, K
end

struct UKFMeasurementModel{IPM,AUGM,MT,RT,IT,MET,CT,CCT,CAT} <: AbstractMeasurementModel
Expand Down Expand Up @@ -156,7 +215,6 @@ function UKFMeasurementModel{T,IPM,AUGM}(
end



struct SigmaPointCache{X0, X1}
x0::X0
x1::X1
Expand Down Expand Up @@ -306,3 +364,26 @@ struct LinearMeasurementModel{CT,DT,RT,CAT} <: AbstractMeasurementModel
end

LinearMeasurementModel(C, D, R2; ny = size(R2, 1), cache = nothing, nx=nothing) = LinearMeasurementModel(C, D, R2, ny, cache)
isinplace(::LinearMeasurementModel) = false

function (model::LinearMeasurementModel)(x,u,p,t)
y = model.C*x
if !iszero(model.D)
if y isa SVector
y += model.D*u
else
mul!(y, model.D, u, 1, 1)
end
end
y
end

function (model::LinearMeasurementModel)(y,x,u,p,t)
mul!(y, model.C, x)
if !iszero(model.D)
mul!(y, model.D, u, 1, 1)
end
y
end

measurement(model::LinearMeasurementModel) = model
2 changes: 2 additions & 0 deletions src/ukf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ function Base.getproperty(ukf::UnscentedKalmanFilter, s::Symbol)
return getfield(mm, s)
elseif s === :nx
return length(getfield(ukf, :x))
elseif s === :measurement
return measurement(mm)
else
throw(ArgumentError("$(typeof(ukf)) has no property named $s"))
end
Expand Down
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,10 @@ end
include("test_parameters.jl")
end

@testset "measurement_models" begin
@info "Testing measurement_models"
include("test_measurement_models.jl")
end


end
Expand Down
4 changes: 2 additions & 2 deletions test/test_large.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,9 @@ plot(sol_sqkf, plothy = true, plote = true)
mm_ukf = UKFMeasurementModel{Float64, true, false}(measurement_large_ip, R2; nx, ny)
mm_ekf = EKFMeasurementModel{Float64, true}(measurement_large_ip, R2; nx, ny)
mm_kf = LinearMeasurementModel(__C, 0, R2; nx, ny)
mm = CompositeMeasurementModel(mm_ukf, mm_ekf, mm_kf)


mms = [mm_ukf, mm_ekf, mm_kf]
mms = [mm_ukf, mm_ekf, mm_kf, mm]

for mm in mms
@show nameof(typeof(mm))
Expand Down
78 changes: 78 additions & 0 deletions test/test_measurement_models.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
using LowLevelParticleFilters
using Test, Random, LinearAlgebra, Statistics, Test
Random.seed!(0)


## KF

nx = 5 # Dimension of state
nu = 2 # Dimension of input
ny = 3 # Dimension of measurements


# Define linenar state-space system
const __A_ = 0.1*randn(nx, nx)
const __B_ = randn(nx, nu)
const __C_ = randn(ny,nx)

dynamics_l(x,u,p,t) = __A_*x .+ __B_*u
measurement_l(x,u,p,t) = __C_*x

R1 = I(nx)
R2 = I(ny)

T = 200 # Number of time steps
kf = KalmanFilter(__A_, __B_, __C_, 0, R1, R2)
skf = SqKalmanFilter(__A_, __B_, __C_, 0, R1, R2)
ukf = UnscentedKalmanFilter(dynamics_l, measurement_l, R1, R2; ny, nu)
ekf = ExtendedKalmanFilter(dynamics_l, measurement_l, R1, R2; nu)

U = [randn(nu) for _ in 1:T]
x,u,y = LowLevelParticleFilters.simulate(kf, U) # Simuate trajectory using the model in the filter

## Test mixing of measurement models ===========================================

mm_ukf = UKFMeasurementModel{Float64, false, false}(measurement_l, R2; nx, ny)
mm_ekf = EKFMeasurementModel{Float64, false}(measurement_l, R2; nx, ny)
mm_kf = LinearMeasurementModel(__C_, 0, R2; nx, ny)
mm = CompositeMeasurementModel(mm_ukf, mm_ekf, mm_kf)


mms = [mm_ukf, mm_ekf, mm_kf, mm]

for mm in mms
@show nameof(typeof(mm))

correct!(kf, mm, u[1], y[1])
correct!(ekf, mm, u[1], y[1])
correct!(ukf, mm, u[1], y[1])

@test kf.x ekf.x ukf.x
@test kf.R ekf.R ukf.R
end


## Filters with measurement models in them
using Plots
for mm in mms
@show nameof(typeof(mm))

ukf = UnscentedKalmanFilter(dynamics_l, mm, R1; ny, nu)
ekf = ExtendedKalmanFilter(dynamics_l, mm, R1; nu)

correct!(ekf, mm, u[1], y[1])
correct!(ukf, mm, u[1], y[1])

@test ekf.x ukf.x
@test ekf.R ukf.R

sol_ukf = forward_trajectory(ukf, u, y)
sol_ekf = forward_trajectory(ekf, u, y)
plot(sol_ukf) # |> display
plot(sol_ekf) # |> display

@test sol_ukf.x sol_ekf.x
@test sol_ukf.xt sol_ekf.xt
@test sol_ukf.R sol_ekf.R
@test sol_ukf.Rt sol_ekf.Rt
end

0 comments on commit 36da609

Please sign in to comment.