Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace Expronicon with Moshi #3354

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf"
DynamicQuantities = "06fc5a27-2a28-4c7c-a15d-362465fb6821"
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
Expronicon = "6b7a57c9-7cc1-4fdf-b7f5-e857abae3636"
FindFirstFunctions = "64ca27bc-2ba2-4a57-88aa-44e436879224"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
Expand All @@ -36,6 +35,7 @@ Latexify = "23fbe1c1-3f47-55db-b15f-69d7ec21a316"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078"
Moshi = "2e0e35c7-a2e4-4343-998d-7ef72827ed2d"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
Expand Down Expand Up @@ -101,7 +101,6 @@ DomainSets = "0.6, 0.7"
DynamicQuantities = "^0.11.2, 0.12, 0.13, 1"
EnumX = "1.0.4"
ExprTools = "0.1.10"
Expronicon = "0.8"
FindFirstFunctions = "1"
ForwardDiff = "0.10.3"
FunctionWrappers = "1.1"
Expand All @@ -118,6 +117,7 @@ Libdl = "1"
LinearAlgebra = "1"
MLStyle = "0.4.17"
ModelingToolkitStandardLibrary = "2.19"
Moshi = "0.3"
NaNMath = "0.3, 1"
NonlinearSolve = "4.3"
OffsetArrays = "1"
Expand Down
2 changes: 2 additions & 0 deletions src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ using SciMLBase: StandardODEProblem, StandardNonlinearProblem, handle_varmap, Ti
using Distributed
import JuliaFormatter
using MLStyle
import Moshi
using Moshi.Data: @data
using NonlinearSolve
import SCCNonlinearSolve
using Reexport
Expand Down
26 changes: 9 additions & 17 deletions src/clock.jl
Original file line number Diff line number Diff line change
@@ -1,20 +1,12 @@
module InferredClock

export InferredTimeDomain

using Expronicon.ADT: @adt, @match
using SciMLBase: TimeDomain

@adt InferredTimeDomain begin
@data InferredClock begin
Inferred
InferredDiscrete
end

Base.Broadcast.broadcastable(x::InferredTimeDomain) = Ref(x)
const InferredTimeDomain = InferredClock.Type
using .InferredClock: Inferred, InferredDiscrete

end

using .InferredClock
Base.Broadcast.broadcastable(x::InferredTimeDomain) = Ref(x)

struct VariableTimeDomain end
Symbolics.option_to_metadata_type(::Val{:timedomain}) = VariableTimeDomain
Expand All @@ -29,7 +21,7 @@ true if `x` contains only continuous-domain signals.
See also [`has_continuous_domain`](@ref)
"""
function is_continuous_domain(x)
issym(x) && return getmetadata(x, VariableTimeDomain, false) == Continuous
issym(x) && return getmetadata(x, VariableTimeDomain, false) == Continuous()
!has_discrete_domain(x) && has_continuous_domain(x)
end

Expand Down Expand Up @@ -58,8 +50,8 @@ has_time_domain(x::Num) = has_time_domain(value(x))
has_time_domain(x) = false

for op in [Differential]
@eval input_timedomain(::$op, arg = nothing) = Continuous
@eval output_timedomain(::$op, arg = nothing) = Continuous
@eval input_timedomain(::$op, arg = nothing) = Continuous()
@eval output_timedomain(::$op, arg = nothing) = Continuous()
end

"""
Expand Down Expand Up @@ -104,8 +96,8 @@ function is_discrete_domain(x)
!has_discrete_domain(x) && has_continuous_domain(x)
end

sampletime(c) = @match c begin
PeriodicClock(dt, _...) => dt
sampletime(c) = Moshi.Match.@match c begin
PeriodicClock(dt) => dt
_ => nothing
end

Expand Down
12 changes: 6 additions & 6 deletions src/discretedomain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,28 +226,28 @@ Base.:-(k::ShiftIndex, i::Int) = k + (-i)
"""
input_timedomain(op::Operator)

Return the time-domain type (`Continuous` or `InferredDiscrete`) that `op` operates on.
Return the time-domain type (`Continuous()` or `InferredDiscrete()`) that `op` operates on.
"""
function input_timedomain(s::Shift, arg = nothing)
if has_time_domain(arg)
return get_time_domain(arg)
end
InferredDiscrete
InferredDiscrete()
end

"""
output_timedomain(op::Operator)

Return the time-domain type (`Continuous` or `InferredDiscrete`) that `op` results in.
Return the time-domain type (`Continuous()` or `InferredDiscrete()`) that `op` results in.
"""
function output_timedomain(s::Shift, arg = nothing)
if has_time_domain(t, arg)
return get_time_domain(t, arg)
end
InferredDiscrete
InferredDiscrete()
end

input_timedomain(::Sample, _ = nothing) = Continuous
input_timedomain(::Sample, _ = nothing) = Continuous()
output_timedomain(s::Sample, _ = nothing) = s.clock

function input_timedomain(h::Hold, arg = nothing)
Expand All @@ -256,7 +256,7 @@ function input_timedomain(h::Hold, arg = nothing)
end
InferredDiscrete # the Hold accepts any discrete
end
output_timedomain(::Hold, _ = nothing) = Continuous
output_timedomain(::Hold, _ = nothing) = Continuous()

sampletime(op::Sample, _ = nothing) = sampletime(op.clock)
sampletime(op::ShiftIndex, _ = nothing) = sampletime(op.clock)
Expand Down
6 changes: 3 additions & 3 deletions src/systems/clock_inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ end
function ClockInference(ts::TransformationState)
@unpack structure = ts
@unpack graph = structure
eq_domain = TimeDomain[Continuous for _ in 1:nsrcs(graph)]
var_domain = TimeDomain[Continuous for _ in 1:ndsts(graph)]
eq_domain = TimeDomain[Continuous() for _ in 1:nsrcs(graph)]
var_domain = TimeDomain[Continuous() for _ in 1:ndsts(graph)]
inferred = BitSet()
for (i, v) in enumerate(get_fullvars(ts))
d = get_time_domain(ts, v)
Expand Down Expand Up @@ -151,7 +151,7 @@ function split_system(ci::ClockInference{S}) where {S}
get!(clock_to_id, d) do
cid = (cid_counter[] += 1)
push!(id_to_clock, d)
if d == Continuous
if d == Continuous()
continuous_id[] = cid
end
cid
Expand Down
2 changes: 1 addition & 1 deletion src/systems/systemstructure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -662,7 +662,7 @@ function structural_simplify!(state::TearingState, io = nothing; simplify = fals
Dict(v => 0.0 for v in Iterators.flatten(inputs)))
end
ps = [sym isa CallWithMetadata ? sym :
setmetadata(sym, VariableTimeDomain, get(time_domains, sym, Continuous))
setmetadata(sym, VariableTimeDomain, get(time_domains, sym, Continuous()))
for sym in get_ps(sys)]
@set! sys.ps = ps
else
Expand Down
30 changes: 15 additions & 15 deletions test/clock.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,19 +77,19 @@ k = ShiftIndex(d)

d = Clock(dt)
# Note that TearingState reorders the equations
@test eqmap[1] == Continuous
@test eqmap[1] == Continuous()
@test eqmap[2] == d
@test eqmap[3] == d
@test eqmap[4] == d
@test eqmap[5] == Continuous
@test eqmap[6] == Continuous
@test eqmap[5] == Continuous()
@test eqmap[6] == Continuous()

@test varmap[yd] == d
@test varmap[ud] == d
@test varmap[r] == d
@test varmap[x] == Continuous
@test varmap[y] == Continuous
@test varmap[u] == Continuous
@test varmap[x] == Continuous()
@test varmap[y] == Continuous()
@test varmap[u] == Continuous()

@info "Testing shift normalization"
dt = 0.1
Expand Down Expand Up @@ -192,10 +192,10 @@ eqs = [yd ~ Sample(dt)(y)
@test varmap[ud1] == d
@test varmap[yd2] == d2
@test varmap[ud2] == d2
@test varmap[r] == Continuous
@test varmap[x] == Continuous
@test varmap[y] == Continuous
@test varmap[u] == Continuous
@test varmap[r] == Continuous()
@test varmap[x] == Continuous()
@test varmap[y] == Continuous()
@test varmap[u] == Continuous()

@info "test composed systems"

Expand Down Expand Up @@ -241,14 +241,14 @@ eqs = [yd ~ Sample(dt)(y)
ci, varmap = infer_clocks(cl)

@test varmap[f.x] == Clock(0.5)
@test varmap[p.x] == Continuous
@test varmap[p.y] == Continuous
@test varmap[p.x] == Continuous()
@test varmap[p.y] == Continuous()
@test varmap[c.ud] == Clock(0.5)
@test varmap[c.yd] == Clock(0.5)
@test varmap[c.y] == Continuous
@test varmap[c.y] == Continuous()
@test varmap[f.y] == Clock(0.5)
@test varmap[f.u] == Clock(0.5)
@test varmap[p.u] == Continuous
@test varmap[p.u] == Continuous()
@test varmap[c.r] == Clock(0.5)

## Multiple clock rates
Expand Down Expand Up @@ -474,7 +474,7 @@ eqs = [yd ~ Sample(dt)(y)

## Test continuous clock

c = ModelingToolkit.SolverStepClock
c = ModelingToolkit.SolverStepClock()
k = ShiftIndex(c)

@mtkmodel CounterSys begin
Expand Down
Loading