From 03e419a23eaab8b0580c1642c6fd8c9e938326ac Mon Sep 17 00:00:00 2001 From: Fredrik Bagge Carlson Date: Wed, 11 Dec 2024 09:39:40 +0100 Subject: [PATCH] allow different measurement models for EKF --- src/ekf.jl | 61 +++++++++++++++++++---------- src/measurement_model.jl | 84 +++++++++++++++++++++++++++++++++++++++- src/ukf.jl | 6 +-- 3 files changed, 127 insertions(+), 24 deletions(-) diff --git a/src/ekf.jl b/src/ekf.jl index 2b88309..77fd40b 100644 --- a/src/ekf.jl +++ b/src/ekf.jl @@ -1,10 +1,9 @@ abstract type AbstractExtendedKalmanFilter{IPD,IPM} <: AbstractKalmanFilter end -struct ExtendedKalmanFilter{IPD, IPM, KF <: KalmanFilter, F, G, A, C} <: AbstractExtendedKalmanFilter{IPD,IPM} +struct ExtendedKalmanFilter{IPD, IPM, KF <: KalmanFilter, F, G, A} <: AbstractExtendedKalmanFilter{IPD,IPM} kf::KF dynamics::F - measurement::G + measurement_model::G Ajac::A - Cjac::C end """ @@ -28,10 +27,10 @@ See also [`UnscentedKalmanFilter`](@ref) which is typically more accurate than ` """ ExtendedKalmanFilter -function ExtendedKalmanFilter(dynamics, measurement, R1,R2,d0=SimpleMvNormal(Matrix(R1)); nu::Int, ny=nothing, Ts = 1.0, p = NullParameters(), α = 1.0, check = true, Ajac = nothing, Cjac = nothing) +function ExtendedKalmanFilter(dynamics, measurement_model::AbstractMeasurementModel, R1,d0=SimpleMvNormal(Matrix(R1)); nu::Int, ny=measurement_model.ny, Ts = 1.0, p = NullParameters(), α = 1.0, check = true, Ajac = nothing) nx = size(R1,1) - ny = size(R2,1) T = eltype(R1) + R2 = measurement_model.R2 if R1 isa SMatrix x = @SVector zeros(T, nx) u = @SVector zeros(T, nu) @@ -46,12 +45,28 @@ function ExtendedKalmanFilter(dynamics, measurement, R1,R2,d0=SimpleMvNormal(Mat D = zeros(ny, nu) # This one is never needed kf = KalmanFilter(A,B,C,D,R1,R2,d0; Ts, p, α, check) - return ExtendedKalmanFilter(kf, dynamics, measurement; Ajac, Cjac) + return ExtendedKalmanFilter(kf, dynamics, measurement_model; Ajac) end +function ExtendedKalmanFilter(dynamics, measurement, R1,R2,d0=SimpleMvNormal(Matrix(R1)); nu::Int, ny=size(R2,1), Cjac = nothing, kwargs...) + IPM = has_ip(measurement) + T = promote_type(eltype(R1), eltype(R2), eltype(d0)) + nx = size(R1,1) + measurement_model = EKFMeasurementModel{T, IPM}(measurement, R2; nx, ny, Cjac) + return ExtendedKalmanFilter(dynamics, measurement_model, R1, d0; nu, kwargs...) +end + + function ExtendedKalmanFilter(kf, dynamics, measurement; Ajac = nothing, Cjac = nothing) IPD = has_ip(dynamics) - IPM = has_ip(measurement) + if measurement isa AbstractMeasurementModel + measurement_model = measurement + IPM = isinplace(measurement_model) + else + IPM = has_ip(measurement) + T = promote_type(eltype(kf.R1), eltype(kf.R2), eltype(kf.d0)) + measurement_model = EKFMeasurementModel{T, IPM}(measurement, kf.R2; kf.nx, kf.ny, Cjac) + end if Ajac === nothing # if IPD # inner! = (xd,x)->dynamics(xd,x,u,p,t) @@ -72,21 +87,21 @@ function ExtendedKalmanFilter(kf, dynamics, measurement; Ajac = nothing, Cjac = Ajac = (x,u,p,t) -> ForwardDiff.jacobian(x->dynamics(x,u,p,t), x) end end - if Cjac === nothing - if IPM - outy = zeros(eltype(kf.d0), kf.ny) - jacy = zeros(eltype(kf.d0), kf.ny, kf.nx) - Cjac = (x,u,p,t) -> ForwardDiff.jacobian!(jacy, (y,x)->measurement(y,x,u,p,t), outy, x) - else - Cjac = (x,u,p,t) -> ForwardDiff.jacobian(x->measurement(x,u,p,t), x) - end - end - return ExtendedKalmanFilter{IPD,IPM,typeof(kf),typeof(dynamics),typeof(measurement),typeof(Ajac),typeof(Cjac)}(kf, dynamics, measurement, Ajac, Cjac) + + return ExtendedKalmanFilter{IPD,IPM,typeof(kf),typeof(dynamics),typeof(measurement_model),typeof(Ajac)}(kf, dynamics, measurement_model, Ajac) end function Base.getproperty(ekf::EKF, s::Symbol) where EKF <: AbstractExtendedKalmanFilter s ∈ fieldnames(EKF) && return getfield(ekf, s) - return getproperty(getfield(ekf, :kf), s) + mm = getfield(ekf, :measurement_model) + if s ∈ fieldnames(typeof(mm)) + return getfield(mm, s) + end + kf = getfield(ekf, :kf) + if s ∈ fieldnames(typeof(kf)) + return getproperty(kf, s) + end + error("$(typeof(ekf)) has no property named $s") end function Base.setproperty!(ekf::ExtendedKalmanFilter, s::Symbol, val) @@ -117,9 +132,15 @@ function predict!(kf::AbstractExtendedKalmanFilter{IPD}, u, p = parameters(kf), kf.t += 1 end -function correct!(kf::AbstractExtendedKalmanFilter{<:Any, IPM}, u, y, p = parameters(kf), t::Real = index(kf); R2 = get_mat(kf.R2, kf.x, u, p, t), measurement = kf.measurement) where IPM +function correct!(ukf::AbstractExtendedKalmanFilter, u, y, p, t::Real; kwargs...) + measurement_model = ukf.measurement_model + correct!(ukf, measurement_model, u, y, p, t::Real; kwargs...) +end + +function correct!(kf::AbstractExtendedKalmanFilter{<:Any, IPM}, measurement_model::EKFMeasurementModel, u, y, p = parameters(kf), t::Real = index(kf); R2 = get_mat(kf.measurement_model.R2, kf.x, u, p, t))where IPM @unpack x,R = kf - C = kf.Cjac(x, u, p, t) + (; measurement, Cjac) = measurement_model + C = Cjac(x, u, p, t) if IPM e = zeros(length(y)) measurement(e, x, u, p, t) diff --git a/src/measurement_model.jl b/src/measurement_model.jl index 3e99b01..febcd8e 100644 --- a/src/measurement_model.jl +++ b/src/measurement_model.jl @@ -16,6 +16,8 @@ struct UKFMeasurementModel{IPM,AUGM,MT,RT,IT,MET,CT,CCT,CAT} <: AbstractMeasurem cache::CAT end +isinplace(::UKFMeasurementModel{IPM}) where IPM = IPM + UKFMeasurementModel{IPM,AUGM}( measurement, R2, @@ -35,7 +37,7 @@ UKFMeasurementModel{IPM,AUGM}( typeof(mean), typeof(cov), typeof(cross_cov), - cache, + typeof(cache), }( measurement, R2, @@ -143,3 +145,83 @@ function SigmaPointCache{T}(nx, nw, ny, L, static) where T end Base.eltype(spc::SigmaPointCache) = eltype(spc.x0) + + +## EKF measurement model ======================================================= + +struct EKFMeasurementModel{IPM,MT,RT,CJ,CAT} <: AbstractMeasurementModel + measurement::MT + R2::RT + ny::Int + Cjac::CJ + cache::CAT +end + +isinplace(::EKFMeasurementModel{IPM}) where IPM = IPM + +EKFMeasurementModel{IPM}( + measurement, + R2, + ny, + Cjac, + cache = nothing, +) where {IPM} = EKFMeasurementModel{ + IPM, + typeof(measurement), + typeof(R2), + typeof(Cjac), + typeof(cache), +}( + measurement, + R2, + ny, + Cjac, + cache, +) + + +function add_cache(model::EKFMeasurementModel{IPM}, cache) where {IPM} + EKFMeasurementModel{eltype(model.cache),IPM}( + model.measurement, + model.R2, + model.ny, + model.Cjac, + cache, + ) +end + + +function EKFMeasurementModel{T,IPM}( + measurement::M, + R2; + nx, + ny, + Cjac = nothing, +) where {T,IPM,M} + + + if Cjac === nothing + if IPM + outy = zeros(T, ny) + jacy = zeros(T, ny, nx) + Cjac = (x,u,p,t) -> ForwardDiff.jacobian!(jacy, (y,x)->measurement(y,x,u,p,t), outy, x) + else + Cjac = (x,u,p,t) -> ForwardDiff.jacobian(x->measurement(x,u,p,t), x) + end + end + + + EKFMeasurementModel{ + IPM, + typeof(measurement), + typeof(R2), + typeof(Cjac), + typeof(nothing), + }( + measurement, + R2, + ny, + Cjac, + nothing, + ) +end \ No newline at end of file diff --git a/src/ukf.jl b/src/ukf.jl index a5f1f38..e87a50d 100644 --- a/src/ukf.jl +++ b/src/ukf.jl @@ -361,7 +361,7 @@ function correct!( ym = mean(ys) C = measurement_cov(xsm, x, ys, ym) e = innovation(y, ym) - S = compute_S(measurement_model, R2) + S = compute_S(measurement_model, R2, ym) Sᵪ = cholesky(Symmetric(S); check = false) issuccess(Sᵪ) || error("Cholesky factorization of innovation covariance failed, got S = ", S) @@ -452,11 +452,11 @@ function propagate_sigmapoints_c!( end end -function compute_S(measurement_model::UKFMeasurementModel{<:Any, AUGM}, R2) where AUGM +function compute_S(measurement_model::UKFMeasurementModel{<:Any, AUGM}, R2, ym) where AUGM sigma_point_cache = measurement_model.cache ys = sigma_point_cache.x1 cov = measurement_model.cov - S = symmetrize(cov(ys)) + S = symmetrize(cov(ys, ym)) if !AUGM if S isa SMatrix || S isa Symmetric{<:Any,<:SMatrix} S += R2