Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

don't collect observed ArraySymbolics #81

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/state_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,13 @@ function _getu(sys, ::ArraySymbolic, ::SymbolicTypeTrait, sym)
return getu(sys, idx)
elseif is_parameter(sys, sym)
return getp(sys, sym)
elseif is_observed(sys, sym)
obs = observed(sys, sym)
if is_time_dependent(sys)
return TimeDependentObservedFunction(obs)
else
return TimeIndependentObservedFunction(obs)
end
end
return getu(sys, collect(sym))
end
Expand Down
4 changes: 4 additions & 0 deletions test/downstream/Project.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
[deps]
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"

[compat]
SymbolicUtils = "1.6"
69 changes: 69 additions & 0 deletions test/state_indexing_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -215,3 +215,72 @@ for (sym, val) in [(:a, p[1])
@inferred get(fi)
@test get(fi) == val
end

# test handling of ArraySymbolics
abstract type FakeArraySymbolic end
struct UIndex{T} <: FakeArraySymbolic
val::T
end
struct PIndex{T} <: FakeArraySymbolic
val::T
end
struct OIndex{T} <: FakeArraySymbolic
val::T
end
SymbolicIndexingInterface.symbolic_type(::Type{<:FakeArraySymbolic}) = ArraySymbolic()

struct FakeArraySymbolicSystem
u::Vector{Float64}
p::Vector{Float64}
end
Base.getindex(sys::FakeArraySymbolicSystem, sym) = getu(sys, sym)(sys)
SymbolicIndexingInterface.symbolic_container(sys::FakeArraySymbolicSystem) = sys
SymbolicIndexingInterface.is_time_dependent(sys::FakeArraySymbolicSystem) = false

# value provider interface
SymbolicIndexingInterface.state_values(s::FakeArraySymbolicSystem) = s.u
SymbolicIndexingInterface.parameter_values(s::FakeArraySymbolicSystem) = s.p
SymbolicIndexingInterface.current_time(s::FakeArraySymbolicSystem) = nothing
function SymbolicIndexingInterface.set_state!(s::FakeArraySymbolicSystem, val, idx)
set_state!(s.u, val, idx)
end
function SymbolicIndexingInterface.set_parameter!(s::FakeArraySymbolicSystem, val, idx)
set_parameter!(s.p, val, idx)
end

# index provider interface
SymbolicIndexingInterface.is_variable(sys::FakeArraySymbolicSystem, sym) = false
function SymbolicIndexingInterface.is_variable(sys::FakeArraySymbolicSystem, sym::UIndex)
sym.val ⊆ eachindex(sys.u)
end
SymbolicIndexingInterface.variable_index(sys::FakeArraySymbolicSystem, sym) = sym.val

SymbolicIndexingInterface.is_parameter(sys::FakeArraySymbolicSystem, sym) = false
function SymbolicIndexingInterface.is_parameter(sys::FakeArraySymbolicSystem, sym::PIndex)
sym.val ⊆ eachindex(sys.p)
end
SymbolicIndexingInterface.parameter_index(sys::FakeArraySymbolicSystem, sym) = sym.val

SymbolicIndexingInterface.is_observed(sys::FakeArraySymbolicSystem, sym) = false
function SymbolicIndexingInterface.is_observed(sys::FakeArraySymbolicSystem, sym::OIndex)
sym.val ⊆ eachindex(sys.p) ∩ eachindex(sys.u)
end
function SymbolicIndexingInterface.observed(sys::FakeArraySymbolicSystem, sym)
(u, p) -> u[sym.val] + p[sym.val]
end

# create fake system
u = rand(10)
p = rand(12)
sys = FakeArraySymbolicSystem(u, p)

# make sure we cannot collect, batched handling is required
@test_throws MethodError collect(UIndex(1:5))
@test_throws MethodError collect(PIndex(1:5))
@test_throws MethodError collect(OIndex(1:5))

# assert batched handling if is_[variable,perameter,observable] returns true
r = 1:10
@test sys[UIndex(r)] == u[r]
@test sys[PIndex(r)] == p[r]
@test sys[OIndex(r)] == p[r] + u[r]