Skip to content

Commit

Permalink
Merge pull request #800 from oscardssmith/os/ode-nlfunc-support
Browse files Browse the repository at this point in the history
Add nlprob to ODEFunction
  • Loading branch information
ChrisRackauckas authored Oct 30, 2024
2 parents 4be9585 + a91d8b3 commit c341819
Showing 1 changed file with 41 additions and 27 deletions.
68 changes: 41 additions & 27 deletions src/scimlfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,10 @@ the usage of `f`. These include:
based on the sparsity pattern. Defaults to `nothing`, which means a color vector will be
internally computed on demand when required. The cost of this operation is highly dependent
on the sparsity pattern.
- `nlprob`: a `NonlinearProblem` that solves `f(u, t, p) = u_tmp`
where the nonlinear parameters are the tuple `(t, u_tmp, p)`.
This will be used as the nonlinear problem inside an implicit solver by specifying `u, u_tmp` and `t`
such that solving this function produces a solution to the implicit step of your solver.
## iip: In-Place vs Out-Of-Place
Expand Down Expand Up @@ -401,8 +405,8 @@ automatically symbolically generating the Jacobian and more from the
numerically-defined functions.
"""
struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, WP, TPJ,
O, TCV,
SYS, IProb, UIProb, IProbMap, IProbPmap} <: AbstractODEFunction{iip}
O, TCV, SYS,
IProb, UIProb, IProbMap, IProbPmap, NLP} <: AbstractODEFunction{iip}
f::F
mass_matrix::TMM
analytic::Ta
Expand All @@ -423,6 +427,7 @@ struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TW
update_initializeprob!::UIProb
initializeprobmap::IProbMap
initializeprobpmap::IProbPmap
nlprob::NLP
end

@doc doc"""
Expand Down Expand Up @@ -525,8 +530,8 @@ information on generating the SplitFunction from this symbolic engine.
"""
struct SplitFunction{
iip, specialize, F1, F2, TMM, C, Ta, Tt, TJ, JVP, VJP, JP, WP, SP, TW, TWt,
TPJ, O,
TCV, SYS, IProb, UIProb, IProbMap, IProbPmap} <: AbstractODEFunction{iip}
TPJ, O, TCV, SYS,
IProb, UIProb, IProbMap, IProbPmap, NLP} <: AbstractODEFunction{iip}
f1::F1
f2::F2
mass_matrix::TMM
Expand All @@ -549,6 +554,7 @@ struct SplitFunction{
update_initializeprob!::UIProb
initializeprobmap::IProbMap
initializeprobpmap::IProbPmap
nlprob::NLP
end

@doc doc"""
Expand Down Expand Up @@ -2432,7 +2438,8 @@ function ODEFunction{iip, specialize}(f;
update_initializeprob! = __has_update_initializeprob!(f) ?
f.update_initializeprob! : nothing,
initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing,
initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing
initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing,
nlprob = __has_nlprob(f) ? f.nlprob : nothing,
) where {iip,
specialize
}
Expand Down Expand Up @@ -2490,11 +2497,11 @@ function ODEFunction{iip, specialize}(f;
typeof(sparsity), Any, Any, typeof(W_prototype), Any,
Any,
typeof(_colorvec),
typeof(sys), Any, Any, Any, Any}(_f, mass_matrix, analytic, tgrad, jac,
typeof(sys), Any, Any, Any, Any, Any}(_f, mass_matrix, analytic, tgrad, jac,
jvp, vjp, jac_prototype, sparsity, Wfact,
Wfact_t, W_prototype, paramjac,
observed, _colorvec, sys, initializeprob, update_initializeprob!, initializeprobmap,
initializeprobpmap)
initializeprobpmap, nlprob)
elseif specialize === false
ODEFunction{iip, FunctionWrapperSpecialize,
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
Expand All @@ -2503,13 +2510,16 @@ function ODEFunction{iip, specialize}(f;
typeof(paramjac),
typeof(observed),
typeof(_colorvec),
typeof(sys), typeof(initializeprob), typeof(update_initializeprob!),
typeof(initializeprobmap), typeof(initializeprobpmap)}(_f, mass_matrix,
typeof(sys), typeof(initializeprob),
typeof(update_initializeprob!),
typeof(initializeprobmap),
typeof(initializeprobpmap),
typeof(nlprob)}(_f, mass_matrix,
analytic, tgrad, jac,
jvp, vjp, jac_prototype, sparsity, Wfact,
Wfact_t, W_prototype, paramjac,
observed, _colorvec, sys, initializeprob, update_initializeprob!, initializeprobmap,
initializeprobpmap)
observed, _colorvec, sys, initializeprob, update_initializeprob!,
initializeprobmap, initializeprobpmap, nlprob)
else
ODEFunction{iip, specialize,
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
Expand All @@ -2520,11 +2530,12 @@ function ODEFunction{iip, specialize}(f;
typeof(_colorvec),
typeof(sys), typeof(initializeprob), typeof(update_initializeprob!),
typeof(initializeprobmap),
typeof(initializeprobpmap)}(_f, mass_matrix, analytic, tgrad, jac,
typeof(initializeprobpmap),
typeof(nlprob)}(_f, mass_matrix, analytic, tgrad, jac,
jvp, vjp, jac_prototype, sparsity, Wfact,
Wfact_t, W_prototype, paramjac,
observed, _colorvec, sys, initializeprob, update_initializeprob!, initializeprobmap,
initializeprobpmap)
initializeprobpmap, nlprob)
end
end

Expand All @@ -2541,13 +2552,13 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f))
Any, Any, Any, Any, typeof(f.jac_prototype),
typeof(f.sparsity), Any, Any, Any,
Any, typeof(f.colorvec),
typeof(f.sys), Any, Any, Any, Any}(
typeof(f.sys), Any, Any, Any, Any, Any}(
newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
f.Wfact_t, f.W_prototype, f.paramjac,
f.observed, f.colorvec, f.sys, f.initializeprob,
f.update_initializeprob!, f.initializeprobmap,
f.initializeprobpmap)
f.initializeprobpmap, f.nlprob)
else
ODEFunction{isinplace(f), specialization(f), typeof(newf), typeof(f.mass_matrix),
typeof(f.analytic), typeof(f.tgrad),
Expand All @@ -2557,11 +2568,12 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f))
typeof(f.observed), typeof(f.colorvec),
typeof(f.sys), typeof(f.initializeprob), typeof(f.update_initializeprob!),
typeof(f.initializeprobmap),
typeof(f.initializeprobpmap)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
typeof(f.initializeprobpmap),
typeof(f.nlprob)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
f.Wfact_t, f.W_prototype, f.paramjac,
f.observed, f.colorvec, f.sys, f.initializeprob, f.update_initializeprob!,
f.initializeprobmap, f.initializeprobpmap)
f.initializeprobmap, f.initializeprobpmap, f.nlprob)
end
end

Expand Down Expand Up @@ -2693,7 +2705,7 @@ end
@add_kwonly function SplitFunction(f1, f2, mass_matrix, cache, analytic, tgrad, jac, jvp,
vjp, jac_prototype, W_prototype, sparsity, Wfact, Wfact_t, paramjac,
observed, colorvec, sys, initializeprob, update_initializeprob!,
initializeprobmap, initializeprobpmap)
initializeprobmap, initializeprobpmap, nlprob)
f1 = ODEFunction(f1)
f2 = ODEFunction(f2)

Expand All @@ -2708,11 +2720,11 @@ end
typeof(vjp), typeof(jac_prototype), typeof(W_prototype), typeof(sparsity),
typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed), typeof(colorvec),
typeof(sys), typeof(initializeprob), typeof(update_initializeprob!), typeof(initializeprobmap),
typeof(initializeprobpmap)}(
typeof(initializeprobpmap), typeof(nlprob)}(
f1, f2, mass_matrix,
cache, analytic, tgrad, jac, jvp, vjp,
jac_prototype, W_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys,
initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap)
initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap, nlprob)
end
function SplitFunction{iip, specialize}(f1, f2;
mass_matrix = __has_mass_matrix(f1) ?
Expand Down Expand Up @@ -2748,7 +2760,8 @@ function SplitFunction{iip, specialize}(f1, f2;
update_initializeprob! = __has_update_initializeprob!(f1) ?
f1.update_initializeprob! : nothing,
initializeprobmap = __has_initializeprobmap(f1) ? f1.initializeprobmap : nothing,
initializeprobpmap = __has_initializeprobpmap(f1) ? f1.initializeprobpmap : nothing
initializeprobpmap = __has_initializeprobpmap(f1) ? f1.initializeprobpmap : nothing,
nlprob = __has_nlprob(f1) ? f1.nlprob : nothing
) where {iip,
specialize
}
Expand All @@ -2759,12 +2772,12 @@ function SplitFunction{iip, specialize}(f1, f2;
if specialize === NoSpecialize
SplitFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any, Any, Any,
Any, Any, Any, Any, Any, Any, Any,
Any, Any, Any, Any, Any, Any}(f1, f2, mass_matrix, _func_cache,
Any, Any, Any, Any, Any, Any, Any}(f1, f2, mass_matrix, _func_cache,
analytic,
tgrad, jac, jvp, vjp, jac_prototype, W_prototype,
sparsity, Wfact, Wfact_t, paramjac,
observed, colorvec, sys, initializeprob.update_initializeprob!, initializeprobmap,
initializeprobpmap, initializeprobpmap)
initializeprobpmap, initializeprobpmap, nlprob)
else
SplitFunction{iip, specialize, typeof(f1), typeof(f2), typeof(mass_matrix),
typeof(_func_cache), typeof(analytic),
Expand All @@ -2774,11 +2787,11 @@ function SplitFunction{iip, specialize}(f1, f2;
typeof(colorvec),
typeof(sys), typeof(initializeprob), typeof(update_initializeprob!),
typeof(initializeprobmap),
typeof(initializeprobpmap)}(f1, f2,
typeof(initializeprobpmap), typeof(nlprob)}(f1, f2,
mass_matrix, _func_cache, analytic, tgrad, jac,
jvp, vjp, jac_prototype, W_prototype,
sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys,
initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap)
initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap, nlprob)
end
end

Expand Down Expand Up @@ -3121,7 +3134,7 @@ SDEFunction(f::SDEFunction; kwargs...) = f

@add_kwonly function SplitSDEFunction(f1, f2, g, mass_matrix, cache, analytic, tgrad, jac,
jvp, vjp,
jac_prototype, W_prototype, Wfact, Wfact_t, paramjac, observed,
jac_prototype, Wfact, Wfact_t, paramjac, observed,
colorvec, sys)
f1 = f1 isa AbstractSciMLOperator ? f1 : SDEFunction(f1)
f2 = SDEFunction(f2)
Expand All @@ -3132,7 +3145,7 @@ SDEFunction(f::SDEFunction; kwargs...) = f
typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed),
typeof(colorvec),
typeof(sys)}(f1, f2, mass_matrix, cache, analytic, tgrad, jac,
jac_prototype, W_prototype, Wfact, Wfact_t, paramjac, observed, colorvec, sys)
jac_prototype, Wfact, Wfact_t, paramjac, observed, colorvec, sys)
end

function SplitSDEFunction{iip, specialize}(f1, f2, g;
Expand Down Expand Up @@ -4411,6 +4424,7 @@ __has_initializeprob(f) = isdefined(f, :initializeprob)
__has_update_initializeprob!(f) = isdefined(f, :update_initializeprob!)
__has_initializeprobmap(f) = isdefined(f, :initializeprobmap)
__has_initializeprobpmap(f) = isdefined(f, :initializeprobpmap)
__has_nlprob(f) = isdefined(f, :nlprob)

# compatibility
has_invW(f::AbstractSciMLFunction) = false
Expand Down

0 comments on commit c341819

Please sign in to comment.