Commit

add adaptive neural-network tutorial
and add prediction errors to kalman solution
baggepinnen committed Nov 21, 2024
1 parent 1904666 commit f87943e
Showing 6 changed files with 244 additions and 11 deletions.
11 changes: 11 additions & 0 deletions docs/Project.toml
@@ -1,15 +1,26 @@
[deps]
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
ControlSystemsBase = "aaaaaaaa-a6ca-5380-bf3e-84a91bcd477e"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
DisplayAs = "0b91fe84-8a4c-11e9-3e1d-67c38462b6d6"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LowLevelParticleFilters = "d9d29d28-c116-5dba-9239-57a5fe23875b"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SeeToDee = "1c904df7-48cd-41e7-921b-d889ed9a470c"
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[compat]
julia = "1.6"
Lux = "1.3"
SparseConnectivityTracer = "0.6.8"
SparseMatrixColorings = "0.4.10"
DifferentiationInterface = "0.6.23"
ComponentArrays = "0.15.19"
1 change: 1 addition & 0 deletions docs/make.jl
@@ -24,6 +24,7 @@ makedocs(
"Noise tuning and disturbance modeling for Kalman filtering" => "noisetuning.md",
"Particle-filter tutorial" => "beetle_example.md",
"State estimation for DAE systems" => "dae.md",
"Adaptive Neural-Network training" => "neural_network.md",
],
"API" => "api.md",
],
197 changes: 197 additions & 0 deletions docs/src/neural_network.md
@@ -0,0 +1,197 @@
# Adaptive Neural-Network training
In this example, we will demonstrate how we can take the estimation of time-varying parameters to the extreme, and use a nonlinear state estimator to estimate the weights in a neural-network model of a dynamical system.

In the tutorial [Joint state and parameter estimation](@ref), we demonstrated how we can add a parameter as a state variable and let the state estimator estimate this alongside the state. In this example, we will try to learn a black-box model of the system dynamics using a neural network, and treat the network weights as time-varying parameters by adding them to the state.
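Schematically, the augmented model constructed below has the form

```math
\begin{aligned}
x_{t+1} &= f(x_t, u_t, \theta_t) + w^x_t \\
\theta_{t+1} &= \theta_t + w^\theta_t
\end{aligned}
```

where $\theta$ denotes the network weights and the covariance of $w^\theta$ controls how quickly the weights are allowed to change (the implementation below additionally applies a small decay to $\theta$ for regularization).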

We start by generating some data from a simple dynamical system; we continue to use the quadruple-tank system from [Joint state and parameter estimation](@ref).

```@example ADAPTIVE_NN
using LowLevelParticleFilters, Lux, Random, SeeToDee, StaticArrays, Plots, LinearAlgebra, ComponentArrays, DifferentiationInterface, SparseMatrixColorings
using SparseConnectivityTracer: TracerSparsityDetector
using DisplayAs # hide
using LowLevelParticleFilters: SimpleMvNormal
function quadtank(h,u,p,t)
    kc = 0.5
    k1, k2, g = 1.6, 1.6, 9.81
    A1 = A3 = A2 = A4 = 4.9
    a1, a3, a2, a4 = 0.03, 0.03, 0.03, 0.03
    γ1, γ2 = 0.2, 0.2
    if t > 2000
        a1 *= 1.5 # Change the parameter at t = 2000
    end
    ssqrt(x) = √(max(x, zero(x)) + 1e-3) # For numerical robustness at x = 0
    SA[
        -a1/A1 * ssqrt(2g*h[1]) + a3/A1*ssqrt(2g*h[3]) + γ1*k1/A1 * u[1]
        -a2/A2 * ssqrt(2g*h[2]) + a4/A2*ssqrt(2g*h[4]) + γ2*k2/A2 * u[2]
        -a3/A3*ssqrt(2g*h[3]) + (1-γ2)*k2/A3 * u[2]
        -a4/A4*ssqrt(2g*h[4]) + (1-γ1)*k1/A4 * u[1]
    ]
end
Ts = 30 # sample time
discrete_dynamics = SeeToDee.Rk4(quadtank, Ts) # Discretize dynamics
nu = 2 # number of control inputs
nx = 4 # number of state variables
ny = 4 # number of measured outputs
function generate_data()
    measurement(x,u,p,t) = x # we measure the full state vector
    Tperiod = 200
    t = 0:Ts:4000
    u = vcat.((0.25 .* sign.(sin.(2pi/Tperiod .* t)) .+ 0.25) .* sqrt.(rand.()))
    u = SVector{nu, Float32}.(vcat.(u,u))
    x0 = Float32[2,2,3,3]
    x = LowLevelParticleFilters.rollout(discrete_dynamics, x0, u)[1:end-1]
    y = measurement.(x, u, 0, 0)
    y = [Float32.(y .+ 0.01.*randn.()) for y in y] # Add some noise to the measurement
    (; x, u, y, nx, nu, ny, Ts)
end
rng = Random.default_rng()
Random.seed!(rng, 1)
data = generate_data()
nothing # hide
```


## Neural network dynamics
Our neural network will be a small feedforward network built using the package [Lux.jl](https://lux.csail.mit.edu/stable/tutorials/beginner/5_OptimizationIntegration).

```@example ADAPTIVE_NN
ni = ny + nu
nhidden = 8
const model_ = Chain(Dense(ni, nhidden, tanh), Chain(Dense(nhidden, nhidden, tanh), Dense(nhidden, ny)))
```

Since the network is rather small, we will train on the CPU only; this is fast enough for this use case. We may extract the parameters of the network using the function `Lux.setup`, and convert them to a `ComponentArray` to make it easier to refer to different parts of the combined state vector.
```@example ADAPTIVE_NN
dev = cpu_device()
ps, st = Lux.setup(rng, model_) |> dev
parr = ComponentArray(ps)
nothing # hide
```

The dynamics of our black-box model will call the neural network to predict the next state given the current state and input. We bias the dynamics towards low frequencies by adding a multiple of the current state to the prediction of the next state, `0.95*x`. We also add a small amount of weight decay to the parameters of the neural network for regularization, `0.995*p`.
```@example ADAPTIVE_NN
function dynamics(out0, xp0, u, _, t)
    xp = ComponentArray(xp0, getaxes(s0))
    out = ComponentArray(out0, getaxes(s0))
    x = xp.x
    p = xp.p
    xp, _ = Lux.apply(model_, [x; u], p, st) # Network prediction (note: shadows the ComponentArray xp above)
    @. out.x = 0.95f0*x+xp
    @. out.p = 0.995f0*p
    nothing
end
@views measurement(out, x, _, _, _) = out .= x[1:nx] # Assume measurement of the full state vector
nothing # hide
```

For simplicity, we have assumed here that we have access to measurements of the entire state vector of the original process. This is often unrealistic, and if we do not have such access, we may instead augment the measured signals with delayed versions of themselves (sometimes called a delay embedding). This is a common technique in discrete-time system identification, used by, e.g., `ControlSystemIdentification.arx` and `subspaceid`.
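As a rough illustration (not executed as part of this tutorial), delay-embedded regressors could be built from the recorded signals like so; the embedding depth `nd` and the exact layout of the regressor are modeling choices:
```julia
nd = 2 # embedding depth (number of delays), a tuning choice
# Build regressors z[k] = [y(t); y(t-1); ...; u(t-1); ...] from the recorded signals
function delay_embed(y, u, nd)
    map(nd:length(y)) do t
        vcat(reduce(vcat, y[t:-1:t-nd+1]), reduce(vcat, u[t-1:-1:t-nd+1]))
    end
end
z = delay_embed(data.y, data.u, nd) # z[k] would replace the state measurement at time t = k + nd - 1
```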

The initial state of the process `x0` and the initial parameters of the neural network `parr` can now be concatenated to form the initial augmented state `s0`.
```@example ADAPTIVE_NN
x0 = Float32[2; 2; 3; 3]
s0 = ComponentVector(; x=x0, p=parr)
nothing # hide
```

## Kalman filter setup
We will estimate the parameters using two different nonlinear Kalman filters, the [`ExtendedKalmanFilter`](@ref) and the [`UnscentedKalmanFilter`](@ref). The covariance matrices for the filters, `R1, R2`, may be tuned to achieve the desired learning speed of the weights: a larger covariance for the network weights allows faster learning, but also yields more noise in the estimates.
```@example ADAPTIVE_NN
R1 = Diagonal([0.1ones(nx); 0.01ones(length(parr))]) .|> Float32
R2 = Diagonal((1e-2)^2 * ones(ny)) .|> Float32
nothing # hide
```

The [`ExtendedKalmanFilter`](@ref) uses Jacobians of the dynamics and measurement model, and if we do not provide those functions they will be automatically computed using ForwardDiff.jl. Since our Jacobians will be relatively large but sparse in this example, we will make use of the sparsity-aware features of DifferentiationInterface.jl in order to get efficient Jacobian computations.
```@example ADAPTIVE_NN
function Ajacfun(x,u,p,t) # Function that returns a function for the Jacobian of the dynamics
    # For large neural networks, it might be faster to use an OOP formulation with Zygote instead of ForwardDiff. Zygote does not handle the in-place version
    backend = AutoSparse(
        AutoForwardDiff(),
        # AutoZygote(),
        sparsity_detector = TracerSparsityDetector(),
        coloring_algorithm = GreedyColoringAlgorithm(),
    )
    out = similar(getdata(x))
    inner = (out,x)->dynamics(out,x,u,p,t)
    prep = prepare_jacobian(inner, out, backend, getdata(x))
    jac = one(eltype(x)) .* sparsity_pattern(prep)
    function (x,u,p,t)
        inner2 = (out,x)->dynamics(out,x,u,p,t)
        DifferentiationInterface.jacobian!(inner2, out, jac, prep, backend, x)
    end
end
Ajac = Ajacfun(s0, data.u[1], nothing, 0)
const CJ_ = [I(nx) zeros(Float32, nx, length(parr))] # The jacobian of the measurement model is constant
Cjac(x,u,p,t) = CJ_
nothing # hide
```

## Estimation
We may now initialize our filters and perform the estimation. Here, we use the function [`forward_trajectory`](@ref) to perform filtering along the entire data trajectory at once, but the same filters can also be used in a streaming fashion as more data becomes available in real time (a sketch of this is shown at the end of this section).

We plot the one-step ahead prediction of the outputs and compare to the "measured" data.
```@example ADAPTIVE_NN
ekf = ExtendedKalmanFilter(dynamics, measurement, R1, R2, SimpleMvNormal(s0, 100R1); nu, check=false, Ajac, Cjac, Ts)
ukf = UnscentedKalmanFilter(dynamics, measurement, R1, R2, SimpleMvNormal(s0, 100R1); nu, ny, Ts)
@time sole = forward_trajectory(ekf, data.u, data.x)
@time solu = forward_trajectory(ukf, data.u, data.x)
plot(sole, plotx=false, plotxt=false, plotyh=true, plotyht=false, plotu=false, plote=true, name="EKF", layout=(nx, 1))
plot!(solu, plotx=false, plotxt=false, plotyh=true, plotyht=false, plotu=false, plote=true, name="UKF", ploty=false, size=(1200, 1500))
DisplayAs.PNG(Plots.current()) # hide
```

We see that prediction errors, $e$, are large in the beginning when the network weights are randomly initialized, but after about half the trajectory the errors are significantly reduced. Just like in the tutorial [Joint state and parameter estimation](@ref), we modified the true dynamics after some time, at $t=2000$, and we see that the filters are able to adapt to this change after a transient increase in prediction error variance.
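The adaptation can also be quantified numerically, for example by comparing the RMS prediction error over the first and second half of the data (a small sketch, not executed here):
```julia
using Statistics
rms(e) = sqrt(mean(abs2, reduce(hcat, e))) # RMS over all outputs and time steps
T = length(sole.e)
rms(sole.e[1:T÷2]), rms(sole.e[T÷2+1:end]) # the error over the second half should be markedly smaller
```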

We may also plot the evolution of the neural-network weights over time, and see how the filters adapt to the changing dynamics of the system.
```@example ADAPTIVE_NN
plot(
    plot(0:Ts:4000, reduce(hcat, sole.xt)'[:, nx+1:end], title="EKF parameters"),
    plot(0:Ts:4000, reduce(hcat, solu.xt)'[:, nx+1:end], title="UKF parameters"),
    legend = false,
)
DisplayAs.PNG(Plots.current()) # hide
```
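For streaming use, the same estimation can be performed sample by sample as data arrives. The loop below is a minimal sketch (not executed here) that mirrors what `forward_trajectory` does internally; since the dynamics and measurement model ignore the parameter argument `p`, we simply pass `nothing`.
```julia
for (t, (u, y)) in enumerate(zip(data.u, data.x))
    ti = (t - 1) * Ts
    correct!(ekf, u, y, nothing, ti) # measurement update with the newly arrived sample
    predict!(ekf, u, nothing, ti)    # time update to the next sample instant
    x̂ = state(ekf)                   # current estimate of the augmented state [x; p]
end
```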

## Benchmarking
The neural network used in this example has
```@example ADAPTIVE_NN
length(parr)
```
parameters, and the length of the data is
```@example ADAPTIVE_NN
length(data.u)
```

Performing the estimation using the Extended Kalman Filter took
```julia
using BenchmarkTools
@btime forward_trajectory(ekf, data.u, data.x);
# 46.034 ms (77872 allocations: 123.45 MiB)
```
and with the Unscented Kalman Filter
```julia
@btime forward_trajectory(ukf, data.u, data.x);
# 142.608 ms (2134370 allocations: 224.82 MiB)
```

The EKF is a bit faster, which is to be expected. Both methods are very fast from a neural-network training perspective, but the performance will not scale favorably to very large network sizes.

## Closing remarks

We have seen how to train a black-box neural-network dynamics model by treating the parameter estimation as a state-estimation problem. This example is very simple and leaves a lot of room for improvement, such as
- We assumed very little prior knowledge of the dynamics. In practice, we may want to model as much as possible from first principles and add a neural network to capture only the residuals that our first-principles model cannot capture (a sketch of such a hybrid model is shown after this list).
- We used forward-mode AD to compute the Jacobian. The Jacobian of the dynamics has dense rows, which means that it's theoretically favorable to use reverse-mode AD to compute it. This is possible using Zygote.jl, but Zygote does not handle array mutation, and one must thus avoid the in-place version of the dynamics. Since the number of parameters in this example is small, sparse forward-mode AD ended up being slightly faster.
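As an illustration of the first point, a hybrid model could use the neural network only for the residual on top of a nominal first-principles model. The sketch below is hypothetical and reuses `discrete_dynamics` purely as a stand-in for a nominal model; in practice the nominal model would use approximate parameters.
```julia
# Hypothetical sketch: nominal first-principles model + neural-network residual
function hybrid_dynamics(out0, xp0, u, _, t)
    xp  = ComponentArray(xp0, getaxes(s0))
    out = ComponentArray(out0, getaxes(s0))
    x, p = xp.x, xp.p
    x_nominal   = discrete_dynamics(x, u, nothing, t) # physics-based prediction
    residual, _ = Lux.apply(model_, [x; u], p, st)     # network captures unmodeled effects
    @. out.x = x_nominal + residual
    @. out.p = 0.995f0 * p                             # same weight decay as before
    nothing
end
```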
7 changes: 5 additions & 2 deletions src/filtering.jl
@@ -217,17 +217,20 @@ function forward_trajectory(kf::AbstractKalmanFilter, u::AbstractVector, y::Abst
xt = Array{particletype(kf)}(undef,T)
R = Array{covtype(kf)}(undef,T)
Rt = Array{covtype(kf)}(undef,T)
e = similar(y)
ll = zero(eltype(particletype(kf)))
for t = 1:T
ti = (t-1)*kf.Ts
x[t] = state(kf) |> copy
R[t] = covariance(kf) |> copy
ll += correct!(kf, u[t], y[t], p, ti)[1]
lli, ei = correct!(kf, u[t], y[t], p, ti)
ll += lli
e[t] = ei
xt[t] = state(kf) |> copy
Rt[t] = covariance(kf) |> copy
predict!(kf, u[t], p, ti)
end
KalmanFilteringSolution(kf,u,y,x,xt,R,Rt,ll)
KalmanFilteringSolution(kf,u,y,x,xt,R,Rt,ll,e)
end


31 changes: 26 additions & 5 deletions src/solutions.jl
@@ -9,8 +9,16 @@ abstract type AbstractFilteringSolution end
- `R`: predicted covariance matrices ``R(t+1|t)``
- `Rt`: filter covariances ``R(t|t)``
- `ll`: loglikelihood
- `e`: prediction errors
# Plot
The solution object can be plotted
```
plot(sol, plotx=true, plotxt=true, plotu=true, ploty=true, plotyh=true, plotyht=false, plote=false, name="")
```
where `plotx`, `plotxt`, `plotu`, `ploty`, `plotyh`, `plotyht` and `plote` are booleans that control which plots are shown. `name` is a string that is prepended to the labels of the plots, which is useful when plotting multiple solutions in the same plot.
"""
struct KalmanFilteringSolution{F,Tu,Ty,Tx,Txt,TR,TRt,Tll} <: AbstractFilteringSolution
struct KalmanFilteringSolution{F,Tu,Ty,Tx,Txt,TR,TRt,Tll,Te} <: AbstractFilteringSolution
f::F
u::Tu
y::Ty
@@ -19,13 +27,14 @@ struct KalmanFilteringSolution{F,Tu,Ty,Tx,Txt,TR,TRt,Tll} <: AbstractFilteringSo
R::TR
Rt::TRt
ll::Tll
e::Te
end

@recipe function plot(timevec::AbstractVector{<:Real}, sol::KalmanFilteringSolution; plotx = true, plotxt=true, plotu=true, ploty=true, plotyh=false, plotyht=true, name = "")
@recipe function plot(timevec::AbstractVector{<:Real}, sol::KalmanFilteringSolution; plotx = true, plotxt=true, plotu=true, ploty=true, plotyh=true, plotyht=false, plote=false, name = "")
isempty(name) || (name = name*" ")
kf = sol.f
nx, nu, ny = length(sol.x[1]), length(sol.u[1]), length(sol.y[1])
layout --> nx*(plotx || plotxt) + plotu*nu + (ploty || plotyh || plotyht)*ny
layout --> nx*(plotx || plotxt) + plotu*nu + (ploty || plotyh || plotyht || plote)*ny
plotx && @series begin
label --> ["$(name)x$(i)(t|t-1)" for i in 1:nx] |> permutedims
subplot --> (1:nx)'
@@ -47,19 +56,25 @@ end
timevec, reduce(hcat, sol.y)'
end
plotyh && @series begin
label --> ["$(i)(t|t-1)" for i in 1:ny] |> permutedims
label --> ["$(name)$(i)(t|t-1)" for i in 1:ny] |> permutedims
subplot --> (1:ny)' .+ (nx*(plotx || plotxt) + nu*plotu)
linestyle --> :dash
yh = measurement_oop(kf).(sol.x, sol.u, Ref(kf.p), timevec)
timevec, reduce(hcat, yh)'
end
plotyht && @series begin
label --> ["$(i)(t|t)" for i in 1:ny] |> permutedims
label --> ["$(name)$(i)(t|t)" for i in 1:ny] |> permutedims
subplot --> (1:ny)' .+ (nx*(plotx || plotxt) + nu*plotu)
linestyle --> :dash
yht = measurement_oop(kf).(sol.xt, sol.u, Ref(kf.p), timevec)
timevec, reduce(hcat, yht)'
end
plote && @series begin
label --> ["$(name)e$(i)(t|t-1)" for i in 1:ny] |> permutedims
subplot --> (1:ny)' .+ (nx*(plotx || plotxt) + nu*plotu)
linestyle --> :dash
timevec, reduce(hcat, sol.e)'
end
end

@recipe function plot(sol::KalmanFilteringSolution)
@@ -78,6 +93,12 @@ end
- `w`: Weights (log space). These are used for internal computations.
- `we`: Weights (exponentiated / original space). These are the ones to use to compute weighted means etc., they sum to one for each time step.
- `ll`: Log likelihood
# Plot
The solution object can be plotted
```
plot(sol; nbinsy=30, xreal=nothing, dim=nothing, ploty=true)
```
"""
struct ParticleFilteringSolution{F,Tu,Ty,Tx,Tw,Twe,Tll} <: AbstractFilteringSolution
f::F
8 changes: 4 additions & 4 deletions test/test_large.jl
@@ -97,7 +97,7 @@ a = @allocated forward_trajectory(ekf, u, y)

## Plotting ====================================================================
using Plots
plot(sol_kf)
plot(sol_ukf)
plot(sol_ekf)
plot(sol_sqkf)
plot(sol_kf, plotyh = true, plote = true)
plot(sol_ukf, plotyh = true, plote = true)
plot(sol_ekf, plotyh = true, plote = true)
plot(sol_sqkf, plotyh = true, plote = true)
