From 1d5cb3f6d08d705d5e9a0efc2a1b2452fa2243f0 Mon Sep 17 00:00:00 2001 From: Fredrik Bagge Carlson Date: Thu, 19 Dec 2024 10:09:53 +0100 Subject: [PATCH] add tests and docs --- src/imm.jl | 77 ++++++++++++++++++++++++++++++++++++++---------- test/runtests.jl | 5 ++++ test/test_imm.jl | 2 +- 3 files changed, 68 insertions(+), 16 deletions(-) diff --git a/src/imm.jl b/src/imm.jl index 32c712b..4611aec 100644 --- a/src/imm.jl +++ b/src/imm.jl @@ -1,29 +1,41 @@ # Interacting multiple models -mutable struct IMM{MT, PT, XT, RT, μT} <: AbstractFilter +mutable struct IMM{MT, PT, XT, RT, μT, PAT} <: AbstractFilter models::MT P::PT x::XT R::RT μ::μT + p::PAT end """ - IMM(models, P, μ; check = true) + IMM(models, P, μ; check = true, p = NullParameters()) Interacting Multiple Model (IMM) filter. This filter is a combination of multiple Kalman-type filters, each with its own state and covariance. The IMM filter is a probabilistically weighted average of the states and covariances of the individual filters. The weights are determined by the probability matrix `P` and the mixing probabilities `μ`. !!! warning "Experimental" This filter is currently considered experimental and the user interface may change in the future without respecting semantic versioning. +In addition to the [`predict!`](@ref) and [`correct!`](@ref) steps, the IMM filter has an [`interact!`](@ref) method that updates the states and covariances of the individual filters based on the mixing probabilities. The [`combine!`](@ref) method combines the states and covariances of the individual filters into a single state and covariance. These four functions are typically called in either of the orders +- `correct!, combine!, interact!, predict!` (as is done in [`update!`](@ref)) +- `interact!, predict!, correct!, combine!` (as is done in the reference cited below) + +These two orders are cyclic permutations of each other, and the order used in [`update!`](@ref) is chosen to align with the order used in the other filters, where the initial condition is corrected using the first measurement, i.e., we assume the first measurement updates ``x(0|-1)`` to ``x(0|0)``. + +The (combined) state and covariance of the IMM filter is made up of the weighted average of the states and covariances of the individual filters. The weights are the initial mixing probabilities `μ`. + +Ref: "Interacting multiple model methods in target tracking: a survey", E. Mazor; A. Averbuch; Y. Bar-Shalom; J. Dayan + # Arguments: - `models`: An array of Kalman-type filters, such as [`KalmanFilter`](@ref), [`ExtendedKalmanFilter`](@ref), [`UnscentedKalmanFilter`](@ref), etc. The state of each model must have the same meaning, such that forming a weighted average makes sense. - `P`: The mode-transition probability matrix. `P[i,j]` is the probability of transitioning from mode `i` to mode `j` (each row must sum to one). - `μ`: The initial mixing probabilities. `μ[i]` is the probability of being in mode `i` at the initial contidion (must sum to one). - `check`: If `true`, check that the inputs are valid. If `false`, skip the checks. +- `p`: Parameters for the filter. NOTE: this `p` is shared among all internal filters. The internal `p` of each filter will be overridden by this one. """ -function IMM(models, P, μ; check=true) +function IMM(models, P, μ; check=true, p = NullParameters()) if check N = length(models) length(μ) == N || throw(ArgumentError("μ must have the same length as the number of models")) @@ -36,7 +48,7 @@ function IMM(models, P, μ; check=true) end x = sum(i->μ[i]*models[i].x, eachindex(models)) R = sum(i->μ[i]*models[i].R, eachindex(models)) - IMM(models, P, x, R, μ) + IMM(models, P, x, R, μ, p) end function Base.getproperty(imm::IMM, s::Symbol) @@ -49,7 +61,13 @@ function Base.getproperty(imm::IMM, s::Symbol) end +""" + interact!(imm::IMM) + +The interaction step of the IMM filter updates the state and covariance of each internal model based on the mixing probabilities `imm.μ` and the transition probability matrix `imm.P`. +Models with small mixing probabilities will have their states and covariances updated more towards the states and covariances of models with higher mixing probabilities, and vice versa. +""" function interact!(imm::IMM) (; μ, P, models) = imm @assert sum(μ) ≈ 1.0 @@ -64,7 +82,7 @@ function interact!(imm::IMM) for i = eachindex(models) μij = P[i,j] * μ[i] / cj[j] d = models[i].x - new_x[j] - @bangbang new_R[j] .+= μij .* (d * d' .+ models[i].R) + @bangbang new_R[j] .+= symmetrize(μij .* (d * d' .+ models[i].R)) end end for (model, x, R) in zip(models, new_x, new_R) @@ -81,27 +99,41 @@ function predict!(imm::IMM, args...; kwargs...) end end -function correct!(imm::IMM, u, y, t, args...; kwargs...) + + +""" + ll, lls, rest = correct!(imm::IMM, u, y, args; kwargs) + +The correct step of the IMM filter corrects each model with the measurements `y` and control input `u`. The mixing probabilities `imm.μ` are updated based on the likelihood of each model given the measurements and the transition probability matrix `P`. + +The returned tuple consists of the sum of the log-likelihood of all models, the vector of individual log-likelihoods and an array of the rest of the return values from the correct step of each model. +""" +function correct!(imm::IMM, u, y, args...; kwargs...) (; μ, P, models) = imm lls = zeros(eltype(imm.x), length(models)) rest = [] for (j, model) in enumerate(models) - lls[j], others... = correct!(model, u, y, args...; kwargs...) + lls[j], others... = correct!(model, u, y; kwargs...) push!(rest, others) end - μP = P'μ # TODO: verify order we want for P + μP = P'μ new_μ = exp.(lls) .* μP μ .= new_μ ./ sum(new_μ) - sum(lls), rest + sum(lls), lls, rest end +""" + combine!(imm::IMM) + +Combine the models of the IMM filter into a single state `imm.x` and covariance `imm.R`. This is done by taking a weighted average of the states and covariances of the individual models, where the weights are the mixing probabilities `μ`. +""" function combine!(imm::IMM) (; μ, x, R, models) = imm @assert sum(μ) ≈ 1.0 - x = 0*x - R = 0*R + @bangbang x .= 0 .* x + @bangbang R .= 0 .* R for (j, model) in enumerate(models) @bangbang x .+= μ[j] .* model.x @@ -109,7 +141,7 @@ function combine!(imm::IMM) for (j, model) in enumerate(models) d = model.x .- x - @bangbang R .+= μ[j] .* (model.R .+ d * d') + @bangbang R .+= symmetrize(μ[j] .* (model.R .+ d * d')) end imm.x = x @@ -118,11 +150,26 @@ function combine!(imm::IMM) end -function update!(imm::IMM, args...; kwargs...) - ll, rest = correct!(imm, args...; kwargs...) +""" + update!(imm::IMM, u, y, p, t; correct_kwargs = (;), predict_kwargs = (;)) + +The combined udpate for an [`IMM`](@ref) filter performs the following steps: +1. Correct each model with the measurements `y` and control input `u`. +2. Combine the models into a single state and covariance. +3. Interact the models to update their respective state and covariance. +4. Predict each model to the next time step. + +This differs slightly from the udpate step of other filters, where at the end of an update the state of the filter is the one-step ahead _predicted_ value, whereas here each individual filter has a predicted state, but the [`combine!`](@ref) step of the IMM filter hasn't been performed on the predictions yet. The state of the IMM filter is thus ``x(t|t)`` and not ``x(t+1|t)`` like it is for other filters, and each filter internal to the IMM. + +# Arguments: +- `correct_kwargs`: An optional named tuple of keyword arguments that are sent to [`correct!`](@ref). +- `predict_kwargs`: An optional named tuple of keyword arguments that are sent to [`predict!`](@ref). +""" +function update!(imm::IMM, args...; correct_kwargs = (;), predict_kwargs = (;)) + ll, rest = correct!(imm, args...; correct_kwargs...) combine!(imm) interact!(imm) - predict!(imm, args...; kwargs...) + predict!(imm, args...; predict_kwargs...) ll, rest end diff --git a/test/runtests.jl b/test/runtests.jl index 6da73d3..62dbddf 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -374,6 +374,11 @@ end include("test_ekf.jl") end +@testset "imm" begin + @info "Testing imm" + include("test_imm.jl") +end + @testset "parameters" begin @info "Testing parameters" include("test_parameters.jl") diff --git a/test/test_imm.jl b/test/test_imm.jl index c250e43..42445f5 100644 --- a/test/test_imm.jl +++ b/test/test_imm.jl @@ -122,4 +122,4 @@ x,u,y = simulate(imm, T, du) # Simulate the IMM sol = forward_trajectory(imm, u, y) # Forward trajectory -plot(sol) \ No newline at end of file +plot(sol, plotx = true, plotxt=true, plotu=true, ploty=true, plotyh=true, plotyht=true, plote=true, plotR=true, plotRt=true) \ No newline at end of file