Skip to content

Commit

Permalink
add tests and docs
Browse files Browse the repository at this point in the history
  • Loading branch information
baggepinnen committed Dec 19, 2024
1 parent 9bc2213 commit 1d5cb3f
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 16 deletions.
77 changes: 62 additions & 15 deletions src/imm.jl
Original file line number Diff line number Diff line change
@@ -1,29 +1,41 @@
# Interacting multiple models

mutable struct IMM{MT, PT, XT, RT, μT} <: AbstractFilter
mutable struct IMM{MT, PT, XT, RT, μT, PAT} <: AbstractFilter
models::MT
P::PT
x::XT
R::RT
μ::μT
p::PAT
end


"""
IMM(models, P, μ; check = true)
IMM(models, P, μ; check = true, p = NullParameters())
Interacting Multiple Model (IMM) filter. This filter is a combination of multiple Kalman-type filters, each with its own state and covariance. The IMM filter is a probabilistically weighted average of the states and covariances of the individual filters. The weights are determined by the probability matrix `P` and the mixing probabilities `μ`.
!!! warning "Experimental"
This filter is currently considered experimental and the user interface may change in the future without respecting semantic versioning.
In addition to the [`predict!`](@ref) and [`correct!`](@ref) steps, the IMM filter has an [`interact!`](@ref) method that updates the states and covariances of the individual filters based on the mixing probabilities. The [`combine!`](@ref) method combines the states and covariances of the individual filters into a single state and covariance. These four functions are typically called in either of the orders
- `correct!, combine!, interact!, predict!` (as is done in [`update!`](@ref))
- `interact!, predict!, correct!, combine!` (as is done in the reference cited below)
These two orders are cyclic permutations of each other, and the order used in [`update!`](@ref) is chosen to align with the order used in the other filters, where the initial condition is corrected using the first measurement, i.e., we assume the first measurement updates ``x(0|-1)`` to ``x(0|0)``.
The (combined) state and covariance of the IMM filter is made up of the weighted average of the states and covariances of the individual filters. The weights are the initial mixing probabilities `μ`.
Ref: "Interacting multiple model methods in target tracking: a survey", E. Mazor; A. Averbuch; Y. Bar-Shalom; J. Dayan
# Arguments:
- `models`: An array of Kalman-type filters, such as [`KalmanFilter`](@ref), [`ExtendedKalmanFilter`](@ref), [`UnscentedKalmanFilter`](@ref), etc. The state of each model must have the same meaning, such that forming a weighted average makes sense.
- `P`: The mode-transition probability matrix. `P[i,j]` is the probability of transitioning from mode `i` to mode `j` (each row must sum to one).
- `μ`: The initial mixing probabilities. `μ[i]` is the probability of being in mode `i` at the initial contidion (must sum to one).
- `check`: If `true`, check that the inputs are valid. If `false`, skip the checks.
- `p`: Parameters for the filter. NOTE: this `p` is shared among all internal filters. The internal `p` of each filter will be overridden by this one.
"""
function IMM(models, P, μ; check=true)
function IMM(models, P, μ; check=true, p = NullParameters())
if check
N = length(models)
length(μ) == N || throw(ArgumentError("μ must have the same length as the number of models"))
Expand All @@ -36,7 +48,7 @@ function IMM(models, P, μ; check=true)
end
x = sum(i->μ[i]*models[i].x, eachindex(models))
R = sum(i->μ[i]*models[i].R, eachindex(models))
IMM(models, P, x, R, μ)
IMM(models, P, x, R, μ, p)
end

function Base.getproperty(imm::IMM, s::Symbol)
Expand All @@ -49,7 +61,13 @@ function Base.getproperty(imm::IMM, s::Symbol)
end


"""
interact!(imm::IMM)
The interaction step of the IMM filter updates the state and covariance of each internal model based on the mixing probabilities `imm.μ` and the transition probability matrix `imm.P`.
Models with small mixing probabilities will have their states and covariances updated more towards the states and covariances of models with higher mixing probabilities, and vice versa.
"""
function interact!(imm::IMM)
(; μ, P, models) = imm
@assert sum(μ) 1.0
Expand All @@ -64,7 +82,7 @@ function interact!(imm::IMM)
for i = eachindex(models)
μij = P[i,j] * μ[i] / cj[j]
d = models[i].x - new_x[j]
@bangbang new_R[j] .+= μij .* (d * d' .+ models[i].R)
@bangbang new_R[j] .+= symmetrize(μij .* (d * d' .+ models[i].R))
end
end
for (model, x, R) in zip(models, new_x, new_R)
Expand All @@ -81,35 +99,49 @@ function predict!(imm::IMM, args...; kwargs...)
end
end

function correct!(imm::IMM, u, y, t, args...; kwargs...)


"""
ll, lls, rest = correct!(imm::IMM, u, y, args; kwargs)
The correct step of the IMM filter corrects each model with the measurements `y` and control input `u`. The mixing probabilities `imm.μ` are updated based on the likelihood of each model given the measurements and the transition probability matrix `P`.
The returned tuple consists of the sum of the log-likelihood of all models, the vector of individual log-likelihoods and an array of the rest of the return values from the correct step of each model.
"""
function correct!(imm::IMM, u, y, args...; kwargs...)
(; μ, P, models) = imm
lls = zeros(eltype(imm.x), length(models))
rest = []
for (j, model) in enumerate(models)
lls[j], others... = correct!(model, u, y, args...; kwargs...)
lls[j], others... = correct!(model, u, y; kwargs...)
push!(rest, others)
end
μP = P'μ # TODO: verify order we want for P
μP = P'μ
new_μ = exp.(lls) .* μP
μ .= new_μ ./ sum(new_μ)

sum(lls), rest
sum(lls), lls, rest
end

"""
combine!(imm::IMM)
Combine the models of the IMM filter into a single state `imm.x` and covariance `imm.R`. This is done by taking a weighted average of the states and covariances of the individual models, where the weights are the mixing probabilities `μ`.
"""
function combine!(imm::IMM)
(; μ, x, R, models) = imm
@assert sum(μ) 1.0

x = 0*x
R = 0*R
@bangbang x .= 0 .* x
@bangbang R .= 0 .* R

for (j, model) in enumerate(models)
@bangbang x .+= μ[j] .* model.x
end

for (j, model) in enumerate(models)
d = model.x .- x
@bangbang R .+= μ[j] .* (model.R .+ d * d')
@bangbang R .+= symmetrize(μ[j] .* (model.R .+ d * d'))
end

imm.x = x
Expand All @@ -118,11 +150,26 @@ function combine!(imm::IMM)
end


function update!(imm::IMM, args...; kwargs...)
ll, rest = correct!(imm, args...; kwargs...)
"""
update!(imm::IMM, u, y, p, t; correct_kwargs = (;), predict_kwargs = (;))
The combined udpate for an [`IMM`](@ref) filter performs the following steps:
1. Correct each model with the measurements `y` and control input `u`.
2. Combine the models into a single state and covariance.
3. Interact the models to update their respective state and covariance.
4. Predict each model to the next time step.
This differs slightly from the udpate step of other filters, where at the end of an update the state of the filter is the one-step ahead _predicted_ value, whereas here each individual filter has a predicted state, but the [`combine!`](@ref) step of the IMM filter hasn't been performed on the predictions yet. The state of the IMM filter is thus ``x(t|t)`` and not ``x(t+1|t)`` like it is for other filters, and each filter internal to the IMM.
# Arguments:
- `correct_kwargs`: An optional named tuple of keyword arguments that are sent to [`correct!`](@ref).
- `predict_kwargs`: An optional named tuple of keyword arguments that are sent to [`predict!`](@ref).
"""
function update!(imm::IMM, args...; correct_kwargs = (;), predict_kwargs = (;))
ll, rest = correct!(imm, args...; correct_kwargs...)
combine!(imm)
interact!(imm)
predict!(imm, args...; kwargs...)
predict!(imm, args...; predict_kwargs...)
ll, rest
end

Expand Down
5 changes: 5 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,11 @@ end
include("test_ekf.jl")
end

@testset "imm" begin
@info "Testing imm"
include("test_imm.jl")
end

@testset "parameters" begin
@info "Testing parameters"
include("test_parameters.jl")
Expand Down
2 changes: 1 addition & 1 deletion test/test_imm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,4 +122,4 @@ x,u,y = simulate(imm, T, du) # Simulate the IMM

sol = forward_trajectory(imm, u, y) # Forward trajectory

plot(sol)
plot(sol, plotx = true, plotxt=true, plotu=true, ploty=true, plotyh=true, plotyht=true, plote=true, plotR=true, plotRt=true)

0 comments on commit 1d5cb3f

Please sign in to comment.