diff --git a/src/state_indexing.jl b/src/state_indexing.jl index cd16b10f..4c6f92b8 100644 --- a/src/state_indexing.jl +++ b/src/state_indexing.jl @@ -175,6 +175,13 @@ function _getu(sys, ::ArraySymbolic, ::SymbolicTypeTrait, sym) return getu(sys, idx) elseif is_parameter(sys, sym) return getp(sys, sym) + elseif is_observed(sys, sym) + obs = observed(sys, sym) + if is_time_dependent(sys) + return TimeDependentObservedFunction(obs) + else + return TimeIndependentObservedFunction(obs) + end end return getu(sys, collect(sym)) end diff --git a/test/downstream/Project.toml b/test/downstream/Project.toml index 7450e66a..0111974f 100644 --- a/test/downstream/Project.toml +++ b/test/downstream/Project.toml @@ -1,3 +1,7 @@ [deps] SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" +SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" + +[compat] +SymbolicUtils = "1.6" diff --git a/test/state_indexing_test.jl b/test/state_indexing_test.jl index 16c8a718..3451c8f6 100644 --- a/test/state_indexing_test.jl +++ b/test/state_indexing_test.jl @@ -215,3 +215,72 @@ for (sym, val) in [(:a, p[1]) @inferred get(fi) @test get(fi) == val end + +# test handling of ArraySymbolics +abstract type FakeArraySymbolic end +struct UIndex{T} <: FakeArraySymbolic + val::T +end +struct PIndex{T} <: FakeArraySymbolic + val::T +end +struct OIndex{T} <: FakeArraySymbolic + val::T +end +SymbolicIndexingInterface.symbolic_type(::Type{<:FakeArraySymbolic}) = ArraySymbolic() + +struct FakeArraySymbolicSystem + u::Vector{Float64} + p::Vector{Float64} +end +Base.getindex(sys::FakeArraySymbolicSystem, sym) = getu(sys, sym)(sys) +SymbolicIndexingInterface.symbolic_container(sys::FakeArraySymbolicSystem) = sys +SymbolicIndexingInterface.is_time_dependent(sys::FakeArraySymbolicSystem) = false + +# value provider interface +SymbolicIndexingInterface.state_values(s::FakeArraySymbolicSystem) = s.u +SymbolicIndexingInterface.parameter_values(s::FakeArraySymbolicSystem) = s.p +SymbolicIndexingInterface.current_time(s::FakeArraySymbolicSystem) = nothing +function SymbolicIndexingInterface.set_state!(s::FakeArraySymbolicSystem, val, idx) + set_state!(s.u, val, idx) +end +function SymbolicIndexingInterface.set_parameter!(s::FakeArraySymbolicSystem, val, idx) + set_parameter!(s.p, val, idx) +end + +# index provider interface +SymbolicIndexingInterface.is_variable(sys::FakeArraySymbolicSystem, sym) = false +function SymbolicIndexingInterface.is_variable(sys::FakeArraySymbolicSystem, sym::UIndex) + sym.val ⊆ eachindex(sys.u) +end +SymbolicIndexingInterface.variable_index(sys::FakeArraySymbolicSystem, sym) = sym.val + +SymbolicIndexingInterface.is_parameter(sys::FakeArraySymbolicSystem, sym) = false +function SymbolicIndexingInterface.is_parameter(sys::FakeArraySymbolicSystem, sym::PIndex) + sym.val ⊆ eachindex(sys.p) +end +SymbolicIndexingInterface.parameter_index(sys::FakeArraySymbolicSystem, sym) = sym.val + +SymbolicIndexingInterface.is_observed(sys::FakeArraySymbolicSystem, sym) = false +function SymbolicIndexingInterface.is_observed(sys::FakeArraySymbolicSystem, sym::OIndex) + sym.val ⊆ eachindex(sys.p) ∩ eachindex(sys.u) +end +function SymbolicIndexingInterface.observed(sys::FakeArraySymbolicSystem, sym) + (u, p) -> u[sym.val] + p[sym.val] +end + +# create fake system +u = rand(10) +p = rand(12) +sys = FakeArraySymbolicSystem(u, p) + +# make sure we cannot collect, batched handling is required +@test_throws MethodError collect(UIndex(1:5)) +@test_throws MethodError collect(PIndex(1:5)) +@test_throws MethodError collect(OIndex(1:5)) + +# assert batched handling if is_[variable,perameter,observable] returns true +r = 1:10 +@test sys[UIndex(r)] == u[r] +@test sys[PIndex(r)] == p[r] +@test sys[OIndex(r)] == p[r] + u[r]