Skip to content

Commit

Permalink
allow different measurement models for EKF
Browse files Browse the repository at this point in the history
  • Loading branch information
baggepinnen committed Dec 11, 2024
1 parent 21162d3 commit 03e419a
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 24 deletions.
61 changes: 41 additions & 20 deletions src/ekf.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
abstract type AbstractExtendedKalmanFilter{IPD,IPM} <: AbstractKalmanFilter end
struct ExtendedKalmanFilter{IPD, IPM, KF <: KalmanFilter, F, G, A, C} <: AbstractExtendedKalmanFilter{IPD,IPM}
struct ExtendedKalmanFilter{IPD, IPM, KF <: KalmanFilter, F, G, A} <: AbstractExtendedKalmanFilter{IPD,IPM}
kf::KF
dynamics::F
measurement::G
measurement_model::G
Ajac::A
Cjac::C
end

"""
Expand All @@ -28,10 +27,10 @@ See also [`UnscentedKalmanFilter`](@ref) which is typically more accurate than `
"""
ExtendedKalmanFilter

function ExtendedKalmanFilter(dynamics, measurement, R1,R2,d0=SimpleMvNormal(Matrix(R1)); nu::Int, ny=nothing, Ts = 1.0, p = NullParameters(), α = 1.0, check = true, Ajac = nothing, Cjac = nothing)
function ExtendedKalmanFilter(dynamics, measurement_model::AbstractMeasurementModel, R1,d0=SimpleMvNormal(Matrix(R1)); nu::Int, ny=measurement_model.ny, Ts = 1.0, p = NullParameters(), α = 1.0, check = true, Ajac = nothing)
nx = size(R1,1)
ny = size(R2,1)
T = eltype(R1)
R2 = measurement_model.R2
if R1 isa SMatrix
x = @SVector zeros(T, nx)
u = @SVector zeros(T, nu)
Expand All @@ -46,12 +45,28 @@ function ExtendedKalmanFilter(dynamics, measurement, R1,R2,d0=SimpleMvNormal(Mat
D = zeros(ny, nu) # This one is never needed
kf = KalmanFilter(A,B,C,D,R1,R2,d0; Ts, p, α, check)

return ExtendedKalmanFilter(kf, dynamics, measurement; Ajac, Cjac)
return ExtendedKalmanFilter(kf, dynamics, measurement_model; Ajac)
end

function ExtendedKalmanFilter(dynamics, measurement, R1,R2,d0=SimpleMvNormal(Matrix(R1)); nu::Int, ny=size(R2,1), Cjac = nothing, kwargs...)
IPM = has_ip(measurement)
T = promote_type(eltype(R1), eltype(R2), eltype(d0))
nx = size(R1,1)
measurement_model = EKFMeasurementModel{T, IPM}(measurement, R2; nx, ny, Cjac)
return ExtendedKalmanFilter(dynamics, measurement_model, R1, d0; nu, kwargs...)
end


function ExtendedKalmanFilter(kf, dynamics, measurement; Ajac = nothing, Cjac = nothing)
IPD = has_ip(dynamics)
IPM = has_ip(measurement)
if measurement isa AbstractMeasurementModel
measurement_model = measurement
IPM = isinplace(measurement_model)
else
IPM = has_ip(measurement)
T = promote_type(eltype(kf.R1), eltype(kf.R2), eltype(kf.d0))
measurement_model = EKFMeasurementModel{T, IPM}(measurement, kf.R2; kf.nx, kf.ny, Cjac)
end
if Ajac === nothing
# if IPD
# inner! = (xd,x)->dynamics(xd,x,u,p,t)
Expand All @@ -72,21 +87,21 @@ function ExtendedKalmanFilter(kf, dynamics, measurement; Ajac = nothing, Cjac =
Ajac = (x,u,p,t) -> ForwardDiff.jacobian(x->dynamics(x,u,p,t), x)
end
end
if Cjac === nothing
if IPM
outy = zeros(eltype(kf.d0), kf.ny)
jacy = zeros(eltype(kf.d0), kf.ny, kf.nx)
Cjac = (x,u,p,t) -> ForwardDiff.jacobian!(jacy, (y,x)->measurement(y,x,u,p,t), outy, x)
else
Cjac = (x,u,p,t) -> ForwardDiff.jacobian(x->measurement(x,u,p,t), x)
end
end
return ExtendedKalmanFilter{IPD,IPM,typeof(kf),typeof(dynamics),typeof(measurement),typeof(Ajac),typeof(Cjac)}(kf, dynamics, measurement, Ajac, Cjac)

return ExtendedKalmanFilter{IPD,IPM,typeof(kf),typeof(dynamics),typeof(measurement_model),typeof(Ajac)}(kf, dynamics, measurement_model, Ajac)
end

function Base.getproperty(ekf::EKF, s::Symbol) where EKF <: AbstractExtendedKalmanFilter
s fieldnames(EKF) && return getfield(ekf, s)
return getproperty(getfield(ekf, :kf), s)
mm = getfield(ekf, :measurement_model)
if s fieldnames(typeof(mm))
return getfield(mm, s)
end
kf = getfield(ekf, :kf)
if s fieldnames(typeof(kf))
return getproperty(kf, s)
end
error("$(typeof(ekf)) has no property named $s")
end

function Base.setproperty!(ekf::ExtendedKalmanFilter, s::Symbol, val)
Expand Down Expand Up @@ -117,9 +132,15 @@ function predict!(kf::AbstractExtendedKalmanFilter{IPD}, u, p = parameters(kf),
kf.t += 1
end

function correct!(kf::AbstractExtendedKalmanFilter{<:Any, IPM}, u, y, p = parameters(kf), t::Real = index(kf); R2 = get_mat(kf.R2, kf.x, u, p, t), measurement = kf.measurement) where IPM
function correct!(ukf::AbstractExtendedKalmanFilter, u, y, p, t::Real; kwargs...)
measurement_model = ukf.measurement_model
correct!(ukf, measurement_model, u, y, p, t::Real; kwargs...)
end

function correct!(kf::AbstractExtendedKalmanFilter{<:Any, IPM}, measurement_model::EKFMeasurementModel, u, y, p = parameters(kf), t::Real = index(kf); R2 = get_mat(kf.measurement_model.R2, kf.x, u, p, t))where IPM
@unpack x,R = kf
C = kf.Cjac(x, u, p, t)
(; measurement, Cjac) = measurement_model
C = Cjac(x, u, p, t)
if IPM
e = zeros(length(y))
measurement(e, x, u, p, t)
Expand Down
84 changes: 83 additions & 1 deletion src/measurement_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ struct UKFMeasurementModel{IPM,AUGM,MT,RT,IT,MET,CT,CCT,CAT} <: AbstractMeasurem
cache::CAT
end

isinplace(::UKFMeasurementModel{IPM}) where IPM = IPM

UKFMeasurementModel{IPM,AUGM}(
measurement,
R2,
Expand All @@ -35,7 +37,7 @@ UKFMeasurementModel{IPM,AUGM}(
typeof(mean),
typeof(cov),
typeof(cross_cov),
cache,
typeof(cache),
}(
measurement,
R2,
Expand Down Expand Up @@ -143,3 +145,83 @@ function SigmaPointCache{T}(nx, nw, ny, L, static) where T
end

Base.eltype(spc::SigmaPointCache) = eltype(spc.x0)


## EKF measurement model =======================================================

struct EKFMeasurementModel{IPM,MT,RT,CJ,CAT} <: AbstractMeasurementModel
measurement::MT
R2::RT
ny::Int
Cjac::CJ
cache::CAT
end

isinplace(::EKFMeasurementModel{IPM}) where IPM = IPM

EKFMeasurementModel{IPM}(
measurement,
R2,
ny,
Cjac,
cache = nothing,
) where {IPM} = EKFMeasurementModel{
IPM,
typeof(measurement),
typeof(R2),
typeof(Cjac),
typeof(cache),
}(
measurement,
R2,
ny,
Cjac,
cache,
)


function add_cache(model::EKFMeasurementModel{IPM}, cache) where {IPM}
EKFMeasurementModel{eltype(model.cache),IPM}(
model.measurement,
model.R2,
model.ny,
model.Cjac,
cache,
)
end


function EKFMeasurementModel{T,IPM}(
measurement::M,
R2;
nx,
ny,
Cjac = nothing,
) where {T,IPM,M}


if Cjac === nothing
if IPM
outy = zeros(T, ny)
jacy = zeros(T, ny, nx)
Cjac = (x,u,p,t) -> ForwardDiff.jacobian!(jacy, (y,x)->measurement(y,x,u,p,t), outy, x)
else
Cjac = (x,u,p,t) -> ForwardDiff.jacobian(x->measurement(x,u,p,t), x)
end
end


EKFMeasurementModel{
IPM,
typeof(measurement),
typeof(R2),
typeof(Cjac),
typeof(nothing),
}(
measurement,
R2,
ny,
Cjac,
nothing,
)
end
6 changes: 3 additions & 3 deletions src/ukf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ function correct!(
ym = mean(ys)
C = measurement_cov(xsm, x, ys, ym)
e = innovation(y, ym)
S = compute_S(measurement_model, R2)
S = compute_S(measurement_model, R2, ym)
Sᵪ = cholesky(Symmetric(S); check = false)
issuccess(Sᵪ) ||
error("Cholesky factorization of innovation covariance failed, got S = ", S)
Expand Down Expand Up @@ -452,11 +452,11 @@ function propagate_sigmapoints_c!(
end
end

function compute_S(measurement_model::UKFMeasurementModel{<:Any, AUGM}, R2) where AUGM
function compute_S(measurement_model::UKFMeasurementModel{<:Any, AUGM}, R2, ym) where AUGM
sigma_point_cache = measurement_model.cache
ys = sigma_point_cache.x1
cov = measurement_model.cov
S = symmetrize(cov(ys))
S = symmetrize(cov(ys, ym))
if !AUGM
if S isa SMatrix || S isa Symmetric{<:Any,<:SMatrix}
S += R2
Expand Down

0 comments on commit 03e419a

Please sign in to comment.