diff --git a/src/LowLevelParticleFilters.jl b/src/LowLevelParticleFilters.jl index f0df567..e54b129 100644 --- a/src/LowLevelParticleFilters.jl +++ b/src/LowLevelParticleFilters.jl @@ -1,7 +1,7 @@ module LowLevelParticleFilters export KalmanFilter, SqKalmanFilter, UnscentedKalmanFilter, DAEUnscentedKalmanFilter, ExtendedKalmanFilter, ParticleFilter, AuxiliaryParticleFilter, AdvancedParticleFilter, PFstate, index, state, covariance, num_particles, effective_particles, weights, expweights, particles, particletype, smooth, sample_measurement, simulate, loglik, log_likelihood_fun, forward_trajectory, mean_trajectory, mode_trajectory, weighted_mean, weighted_cov, update!, predict!, correct!, reset!, metropolis, shouldresample, TupleProduct -export LinearMeasurementModel, EKFMeasurementModel, UKFMeasurementModel, CompositeMesurementModel +export LinearMeasurementModel, EKFMeasurementModel, UKFMeasurementModel, CompositeMeasurementModel @deprecate weigthed_mean weighted_mean @deprecate weigthed_cov weighted_cov @@ -26,6 +26,7 @@ abstract type ResamplingStrategy end struct ResampleSystematic <: ResamplingStrategy end abstract type AbstractFilter end +abstract type AbstractKalmanFilter <: AbstractFilter end include("PFtypes.jl") include("solutions.jl") diff --git a/src/ekf.jl b/src/ekf.jl index 46f919e..89fa1f6 100644 --- a/src/ekf.jl +++ b/src/ekf.jl @@ -96,6 +96,8 @@ function Base.getproperty(ekf::EKF, s::Symbol) where EKF <: AbstractExtendedKalm mm = getfield(ekf, :measurement_model) if s ∈ fieldnames(typeof(mm)) return getfield(mm, s) + elseif s === :measurement + return measurement(mm) end kf = getfield(ekf, :kf) if s ∈ fieldnames(typeof(kf)) diff --git a/src/kalman.jl b/src/kalman.jl index 376724c..73459ae 100644 --- a/src/kalman.jl +++ b/src/kalman.jl @@ -1,5 +1,3 @@ -abstract type AbstractKalmanFilter <: AbstractFilter end - function convert_cov_type(R1, R) if !(eltype(R) <: AbstractFloat) R = float.(R) diff --git a/src/measurement_model.jl b/src/measurement_model.jl index 1d6bc3b..9c55a91 100644 --- a/src/measurement_model.jl +++ b/src/measurement_model.jl @@ -1,15 +1,74 @@ abstract type AbstractMeasurementModel end +measurement(model::AbstractMeasurementModel) = model.measurement + +struct CompositeMeasurementModel{M} <: AbstractMeasurementModel + models::M + ny::Int + R2 +end + """ - ComponsiteMeasurementModel{M} + CompositeMeasurementModel(model1, model2, ...) + +A composite measurement model that combines multiple measurement models. This model acts as all component models concatenated. The tuple returned from [`correct!`](@ref) will be +- `ll`: The sum of the log-likelihood of all component models +- `e`: The concatenated innovation vector +- `S`: A vector of the innovation covariance matrices of the component models +- `Sᵪ`: A vector of the Cholesky factorizations of the innovation covariance matrices of the component models +- `K`: A vector of the Kalman gains of the component models -A composite measurement model that combines multiple measurement models. +If all sensors operate on at the same rate, and all measurement models are of the same type, it's more efficient to use a single measurement model with a vector-valued measurement function. # Fields: - `models`: A tuple of measurement models """ -struct ComponsiteMeasurementModel{M} <: AbstractMeasurementModel - models::M +function CompositeMeasurementModel(m1, rest...) + models = (m1, rest...) + ny = sum(m.ny for m in models) + R2 = cat([m.R2 for m in models]..., dims=(1,2)) + CompositeMeasurementModel(models, ny, R2) +end + +isinplace(model::CompositeMeasurementModel) = isinplace(model.models[1]) + +function measurement(model::CompositeMeasurementModel) + function (x,u,p,t) + y = zeros(model.ny) + i = 1 + for m in model.models + y[i:i+m.ny-1] .= measurement(m)(x,u,p,t) + i += m.ny + end + y + end +end + +function correct!( + kf::AbstractKalmanFilter, + measurement_model::CompositeMeasurementModel, + u, + y, + p = parameters(kf), + t::Real = index(kf) * kf.Ts; +) + ll = 0.0 + e = zeros(measurement_model.ny) + S = [] + Sᵪ = [] + K = [] + last_ind = 0 + for i = 1:length(measurement_model.models) + lli, ei, Si, Sᵪi, Ki = correct!(kf, measurement_model.models[i], u, y, p, t) + ll += lli + inds = (1:measurement_model.models[i].ny) .+ last_ind + e[inds] .= ei + last_ind = inds[end] + push!(S, Si) + push!(Sᵪ, Sᵪi) + push!(K, Ki) + end + ll, e, S, Sᵪ, K end struct UKFMeasurementModel{IPM,AUGM,MT,RT,IT,MET,CT,CCT,CAT} <: AbstractMeasurementModel @@ -156,7 +215,6 @@ function UKFMeasurementModel{T,IPM,AUGM}( end - struct SigmaPointCache{X0, X1} x0::X0 x1::X1 @@ -306,3 +364,26 @@ struct LinearMeasurementModel{CT,DT,RT,CAT} <: AbstractMeasurementModel end LinearMeasurementModel(C, D, R2; ny = size(R2, 1), cache = nothing, nx=nothing) = LinearMeasurementModel(C, D, R2, ny, cache) +isinplace(::LinearMeasurementModel) = false + +function (model::LinearMeasurementModel)(x,u,p,t) + y = model.C*x + if !iszero(model.D) + if y isa SVector + y += model.D*u + else + mul!(y, model.D, u, 1, 1) + end + end + y +end + +function (model::LinearMeasurementModel)(y,x,u,p,t) + mul!(y, model.C, x) + if !iszero(model.D) + mul!(y, model.D, u, 1, 1) + end + y +end + +measurement(model::LinearMeasurementModel) = model \ No newline at end of file diff --git a/src/ukf.jl b/src/ukf.jl index d838501..15c49e4 100644 --- a/src/ukf.jl +++ b/src/ukf.jl @@ -57,6 +57,8 @@ function Base.getproperty(ukf::UnscentedKalmanFilter, s::Symbol) return getfield(mm, s) elseif s === :nx return length(getfield(ukf, :x)) + elseif s === :measurement + return measurement(mm) else throw(ArgumentError("$(typeof(ukf)) has no property named $s")) end diff --git a/test/runtests.jl b/test/runtests.jl index 54fabf6..6da73d3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -379,6 +379,10 @@ end include("test_parameters.jl") end +@testset "measurement_models" begin + @info "Testing measurement_models" + include("test_measurement_models.jl") +end end diff --git a/test/test_large.jl b/test/test_large.jl index 859a24c..3702eaf 100644 --- a/test/test_large.jl +++ b/test/test_large.jl @@ -110,9 +110,9 @@ plot(sol_sqkf, plothy = true, plote = true) mm_ukf = UKFMeasurementModel{Float64, true, false}(measurement_large_ip, R2; nx, ny) mm_ekf = EKFMeasurementModel{Float64, true}(measurement_large_ip, R2; nx, ny) mm_kf = LinearMeasurementModel(__C, 0, R2; nx, ny) +mm = CompositeMeasurementModel(mm_ukf, mm_ekf, mm_kf) - -mms = [mm_ukf, mm_ekf, mm_kf] +mms = [mm_ukf, mm_ekf, mm_kf, mm] for mm in mms @show nameof(typeof(mm)) diff --git a/test/test_measurement_models.jl b/test/test_measurement_models.jl new file mode 100644 index 0000000..d5d7937 --- /dev/null +++ b/test/test_measurement_models.jl @@ -0,0 +1,78 @@ +using LowLevelParticleFilters +using Test, Random, LinearAlgebra, Statistics, Test +Random.seed!(0) + + +## KF + +nx = 5 # Dimension of state +nu = 2 # Dimension of input +ny = 3 # Dimension of measurements + + +# Define linenar state-space system +const __A_ = 0.1*randn(nx, nx) +const __B_ = randn(nx, nu) +const __C_ = randn(ny,nx) + +dynamics_l(x,u,p,t) = __A_*x .+ __B_*u +measurement_l(x,u,p,t) = __C_*x + +R1 = I(nx) +R2 = I(ny) + +T = 200 # Number of time steps +kf = KalmanFilter(__A_, __B_, __C_, 0, R1, R2) +skf = SqKalmanFilter(__A_, __B_, __C_, 0, R1, R2) +ukf = UnscentedKalmanFilter(dynamics_l, measurement_l, R1, R2; ny, nu) +ekf = ExtendedKalmanFilter(dynamics_l, measurement_l, R1, R2; nu) + +U = [randn(nu) for _ in 1:T] +x,u,y = LowLevelParticleFilters.simulate(kf, U) # Simuate trajectory using the model in the filter + +## Test mixing of measurement models =========================================== + +mm_ukf = UKFMeasurementModel{Float64, false, false}(measurement_l, R2; nx, ny) +mm_ekf = EKFMeasurementModel{Float64, false}(measurement_l, R2; nx, ny) +mm_kf = LinearMeasurementModel(__C_, 0, R2; nx, ny) +mm = CompositeMeasurementModel(mm_ukf, mm_ekf, mm_kf) + + +mms = [mm_ukf, mm_ekf, mm_kf, mm] + +for mm in mms + @show nameof(typeof(mm)) + + correct!(kf, mm, u[1], y[1]) + correct!(ekf, mm, u[1], y[1]) + correct!(ukf, mm, u[1], y[1]) + + @test kf.x ≈ ekf.x ≈ ukf.x + @test kf.R ≈ ekf.R ≈ ukf.R +end + + +## Filters with measurement models in them +using Plots +for mm in mms + @show nameof(typeof(mm)) + + ukf = UnscentedKalmanFilter(dynamics_l, mm, R1; ny, nu) + ekf = ExtendedKalmanFilter(dynamics_l, mm, R1; nu) + + correct!(ekf, mm, u[1], y[1]) + correct!(ukf, mm, u[1], y[1]) + + @test ekf.x ≈ ukf.x + @test ekf.R ≈ ukf.R + + sol_ukf = forward_trajectory(ukf, u, y) + sol_ekf = forward_trajectory(ekf, u, y) + plot(sol_ukf) # |> display + plot(sol_ekf) # |> display + + @test sol_ukf.x ≈ sol_ekf.x + @test sol_ukf.xt ≈ sol_ekf.xt + @test sol_ukf.R ≈ sol_ekf.R + @test sol_ukf.Rt ≈ sol_ekf.Rt +end \ No newline at end of file