From a5bf0019b7d6fe2ccc393562a035f7a05a2bfc68 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hans=20W=C3=BCrfel?= Date: Thu, 30 May 2024 16:00:57 +0200 Subject: [PATCH 1/4] don't collect observed `ArraySymbolics` If `ArraySymbolic` `is_observed` forward directly to `observed` instead of collecting first. --- src/state_indexing.jl | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/state_indexing.jl b/src/state_indexing.jl index cd16b10f..98f0c3d9 100644 --- a/src/state_indexing.jl +++ b/src/state_indexing.jl @@ -175,6 +175,17 @@ function _getu(sys, ::ArraySymbolic, ::SymbolicTypeTrait, sym) return getu(sys, idx) elseif is_parameter(sys, sym) return getp(sys, sym) + elseif is_observed(sys, sym) + obs = observed(sys, sym isa Tuple ? collect(sym) : sym) + getter = if is_time_dependent(sys) + TimeDependentObservedFunction(obs) + else + TimeIndependentObservedFunction(obs) + end + if sym isa Tuple + getter = AsTupleWrapper(getter) + end + return getter end return getu(sys, collect(sym)) end From 05c3a2cd9c01fed0bcd3bc439640c7cd700424b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hans=20W=C3=BCrfel?= Date: Fri, 31 May 2024 09:15:00 +0200 Subject: [PATCH 2/4] remove check for tuple --- src/state_indexing.jl | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/src/state_indexing.jl b/src/state_indexing.jl index 98f0c3d9..4c6f92b8 100644 --- a/src/state_indexing.jl +++ b/src/state_indexing.jl @@ -176,16 +176,12 @@ function _getu(sys, ::ArraySymbolic, ::SymbolicTypeTrait, sym) elseif is_parameter(sys, sym) return getp(sys, sym) elseif is_observed(sys, sym) - obs = observed(sys, sym isa Tuple ? collect(sym) : sym) - getter = if is_time_dependent(sys) - TimeDependentObservedFunction(obs) + obs = observed(sys, sym) + if is_time_dependent(sys) + return TimeDependentObservedFunction(obs) else - TimeIndependentObservedFunction(obs) - end - if sym isa Tuple - getter = AsTupleWrapper(getter) + return TimeIndependentObservedFunction(obs) end - return getter end return getu(sys, collect(sym)) end From 47c85fe31abe18ebbdc76d35817fcf8a62bd16da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hans=20W=C3=BCrfel?= Date: Fri, 31 May 2024 10:17:39 +0200 Subject: [PATCH 3/4] add test to check handling of ArraySymbolics --- test/state_indexing_test.jl | 69 +++++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/test/state_indexing_test.jl b/test/state_indexing_test.jl index 16c8a718..3451c8f6 100644 --- a/test/state_indexing_test.jl +++ b/test/state_indexing_test.jl @@ -215,3 +215,72 @@ for (sym, val) in [(:a, p[1]) @inferred get(fi) @test get(fi) == val end + +# test handling of ArraySymbolics +abstract type FakeArraySymbolic end +struct UIndex{T} <: FakeArraySymbolic + val::T +end +struct PIndex{T} <: FakeArraySymbolic + val::T +end +struct OIndex{T} <: FakeArraySymbolic + val::T +end +SymbolicIndexingInterface.symbolic_type(::Type{<:FakeArraySymbolic}) = ArraySymbolic() + +struct FakeArraySymbolicSystem + u::Vector{Float64} + p::Vector{Float64} +end +Base.getindex(sys::FakeArraySymbolicSystem, sym) = getu(sys, sym)(sys) +SymbolicIndexingInterface.symbolic_container(sys::FakeArraySymbolicSystem) = sys +SymbolicIndexingInterface.is_time_dependent(sys::FakeArraySymbolicSystem) = false + +# value provider interface +SymbolicIndexingInterface.state_values(s::FakeArraySymbolicSystem) = s.u +SymbolicIndexingInterface.parameter_values(s::FakeArraySymbolicSystem) = s.p +SymbolicIndexingInterface.current_time(s::FakeArraySymbolicSystem) = nothing +function SymbolicIndexingInterface.set_state!(s::FakeArraySymbolicSystem, val, idx) + set_state!(s.u, val, idx) +end +function SymbolicIndexingInterface.set_parameter!(s::FakeArraySymbolicSystem, val, idx) + set_parameter!(s.p, val, idx) +end + +# index provider interface +SymbolicIndexingInterface.is_variable(sys::FakeArraySymbolicSystem, sym) = false +function SymbolicIndexingInterface.is_variable(sys::FakeArraySymbolicSystem, sym::UIndex) + sym.val ⊆ eachindex(sys.u) +end +SymbolicIndexingInterface.variable_index(sys::FakeArraySymbolicSystem, sym) = sym.val + +SymbolicIndexingInterface.is_parameter(sys::FakeArraySymbolicSystem, sym) = false +function SymbolicIndexingInterface.is_parameter(sys::FakeArraySymbolicSystem, sym::PIndex) + sym.val ⊆ eachindex(sys.p) +end +SymbolicIndexingInterface.parameter_index(sys::FakeArraySymbolicSystem, sym) = sym.val + +SymbolicIndexingInterface.is_observed(sys::FakeArraySymbolicSystem, sym) = false +function SymbolicIndexingInterface.is_observed(sys::FakeArraySymbolicSystem, sym::OIndex) + sym.val ⊆ eachindex(sys.p) ∩ eachindex(sys.u) +end +function SymbolicIndexingInterface.observed(sys::FakeArraySymbolicSystem, sym) + (u, p) -> u[sym.val] + p[sym.val] +end + +# create fake system +u = rand(10) +p = rand(12) +sys = FakeArraySymbolicSystem(u, p) + +# make sure we cannot collect, batched handling is required +@test_throws MethodError collect(UIndex(1:5)) +@test_throws MethodError collect(PIndex(1:5)) +@test_throws MethodError collect(OIndex(1:5)) + +# assert batched handling if is_[variable,perameter,observable] returns true +r = 1:10 +@test sys[UIndex(r)] == u[r] +@test sys[PIndex(r)] == p[r] +@test sys[OIndex(r)] == p[r] + u[r] From d0c98ba2b280be85dca7a142983765160053c5f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hans=20W=C3=BCrfel?= Date: Mon, 3 Jun 2024 11:11:58 +0200 Subject: [PATCH 4/4] SymbolicUtils@v1 for downstream tests --- test/downstream/Project.toml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/downstream/Project.toml b/test/downstream/Project.toml index 7450e66a..0111974f 100644 --- a/test/downstream/Project.toml +++ b/test/downstream/Project.toml @@ -1,3 +1,7 @@ [deps] SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" +SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" + +[compat] +SymbolicUtils = "1.6"