Skip to content

Commit

Permalink
feat: add support for OverrideInit and CheckInit
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Nov 21, 2024
1 parent d4caf26 commit de5ced3
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 43 deletions.
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ authors = ["Chris Rackauckas <[email protected]>"]
version = "4.26.1"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
Expand All @@ -18,6 +20,8 @@ Sundials_jll = "fb77eaff-e24c-56d4-86b1-d163f2edb164"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"

[compat]
Accessors = "0.1.38"
ArrayInterface = "7.17.1"
CEnum = "0.5"
DataStructures = "0.18"
DiffEqBase = "6.154"
Expand All @@ -36,8 +40,8 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
IncompleteLU = "40713840-3770-5561-ab4c-a76e7d0d7895"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
ODEProblemLibrary = "fdc4e326-1af4-4b90-96e7-779fcce2daa5"
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
Expand Down
3 changes: 3 additions & 0 deletions src/Sundials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ module Sundials
import Reexport
Reexport.@reexport using DiffEqBase
using SciMLBase: AbstractSciMLOperator
import Accessors: @reset
import ArrayInterface
import SymbolicIndexingInterface as SII
import SymbolicIndexingInterface: ParameterIndexingProxy
import DataStructures
Expand Down Expand Up @@ -83,6 +85,7 @@ include("common_interface/verbosity.jl")
include("common_interface/algorithms.jl")
include("common_interface/integrator_types.jl")
include("common_interface/integrator_utils.jl")
include("common_interface/initialize_dae.jl")
include("common_interface/solve.jl")

import PrecompileTools
Expand Down
78 changes: 78 additions & 0 deletions src/common_interface/initialize_dae.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
struct SundialsDefaultInit <: DiffEqBase.DAEInitializationAlgorithm end

function DiffEqBase.initialize_dae!(integrator::AbstractSundialsIntegrator, initializealg = integrator.initializealg)
_initialize_dae!(integrator, integrator.sol.prob, initializealg, Val(DiffEqBase.isinplace(integrator.sol.prob)))
end

struct IDADefaultInit <: DiffEqBase.DAEInitializationAlgorithm
end

function _initialize_dae!(integrator::IDAIntegrator, prob,
initializealg::IDADefaultInit, isinplace)
if integrator.u_modified
IDAReinit!(integrator)
end
integrator.f(integrator.tmp, integrator.du, integrator.u, integrator.p, integrator.t)
tstart, tend = integrator.sol.prob.tspan
if any(abs.(integrator.tmp) .>= integrator.opts.reltol)
if integrator.sol.prob.differential_vars === nothing && !integrator.alg.init_all
error("Must supply differential_vars argument to DAEProblem constructor to use IDA initial value solver.")
end
if integrator.alg.init_all
init_type = IDA_Y_INIT
else
init_type = IDA_YA_YDP_INIT
integrator.flag = IDASetId(integrator.mem,
vec(integrator.sol.prob.differential_vars))
end
dt = integrator.dt == tstart ? tend : integrator.dt
integrator.flag = IDACalcIC(integrator.mem, init_type, dt)

# Reflect consistent initial conditions back into the integrator's
# shadow copy. N.B.: ({du, u}_nvec are aliased to {du, u}).
IDAGetConsistentIC(integrator.mem, integrator.u_nvec, integrator.du_nvec)
end
if integrator.t == tstart && integrator.flag < 0
integrator.sol = SciMLBase.solution_new_retcode(integrator.sol,
ReturnCode.InitialFailure)
end
end

function _initialize_dae!(integrator, prob, ::SundialsDefaultInit, isinplace)
if SciMLBase.has_initializeprob(prob.f)
_initialize_dae!(integrator, prob, SciMLBase.OverrideInit(), isinplace)
elseif integrator isa IDAIntegrator
_initialize_dae!(integrator, prob, IDADefaultInit(), isinplace)
end
end

function _initialize_dae!(integrator, prob, initalg::SciMLBase.NoInit, isinplace) end

function _initialize_dae!(integrator, prob, initalg::SciMLBase.OverrideInit, isinplace::Union{Val{true}, Val{false}})
nlsolve_alg = KINSOL()
u0, p, success = SciMLBase.get_initial_values(prob, integrator, prob.f, initalg, isinplace; nlsolve_alg, abstol = integrator.opts.abstol, reltol = integrator.opts.reltol)

if isinplace === Val{true}()
integrator.u .= u0
if length(integrator.sol.u) == 1
integrator.sol.u[1] .= u0
end
else
integrator.u = u0
if length(integrator.sol.u) == 1
integrator.sol.u[1] = u0
end
end
integrator.p = p
sol = integrator.sol
@reset sol.prob.p = integrator.p
integrator.sol = sol

if !success
integrator.sol = SciMLBase.solution_new_retcode(integrator.sol, ReturnCode.InitialFailure)
end
end

function _initialize_dae!(integrator, prob, initalg::SciMLBase.CheckInit, isinplace::Union{Val{true}, Val{false}})
SciMLBase.get_initial_values(prob, integrator, prob.f, initalg, isinplace; abstol = integrator.opts.abstol)
end
8 changes: 6 additions & 2 deletions src/common_interface/integrator_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ mutable struct CVODEIntegrator{N,
oType,
LStype,
Atype,
CallbackCacheType} <: AbstractSundialsIntegrator{algType}
CallbackCacheType,
IA} <: AbstractSundialsIntegrator{algType}
u::Array{Float64, N}
u_nvec::NVector
p::pType
Expand All @@ -66,6 +67,7 @@ mutable struct CVODEIntegrator{N,
vector_event_last_time::Int
callback_cache::CallbackCacheType
last_event_error::Float64
initializealg::IA
end

function (integrator::CVODEIntegrator)(t::Number,
Expand Down Expand Up @@ -96,7 +98,8 @@ mutable struct ARKODEIntegrator{N,
Atype,
MLStype,
Mtype,
CallbackCacheType} <: AbstractSundialsIntegrator{ARKODE}
CallbackCacheType,
IA} <: AbstractSundialsIntegrator{ARKODE}
u::Array{Float64, N}
u_nvec::NVector
p::pType
Expand Down Expand Up @@ -124,6 +127,7 @@ mutable struct ARKODEIntegrator{N,
vector_event_last_time::Int
callback_cache::CallbackCacheType
last_event_error::Float64
initializealg::IA
end

function (integrator::ARKODEIntegrator)(t::Number,
Expand Down
36 changes: 0 additions & 36 deletions src/common_interface/integrator_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -187,42 +187,6 @@ end
# Required for callbacks
DiffEqBase.set_proposed_dt!(i::AbstractSundialsIntegrator, dt) = nothing

DiffEqBase.initialize_dae!(integrator::AbstractSundialsIntegrator) = nothing

struct IDADefaultInit <: DiffEqBase.DAEInitializationAlgorithm
end

function DiffEqBase.initialize_dae!(integrator::IDAIntegrator,
initializealg::IDADefaultInit)
if integrator.u_modified
IDAReinit!(integrator)
end
integrator.f(integrator.tmp, integrator.du, integrator.u, integrator.p, integrator.t)
tstart, tend = integrator.sol.prob.tspan
if any(abs.(integrator.tmp) .>= integrator.opts.reltol)
if integrator.sol.prob.differential_vars === nothing && !integrator.alg.init_all
error("Must supply differential_vars argument to DAEProblem constructor to use IDA initial value solver.")
end
if integrator.alg.init_all
init_type = IDA_Y_INIT
else
init_type = IDA_YA_YDP_INIT
integrator.flag = IDASetId(integrator.mem,
vec(integrator.sol.prob.differential_vars))
end
dt = integrator.dt == tstart ? tend : integrator.dt
integrator.flag = IDACalcIC(integrator.mem, init_type, dt)

# Reflect consistent initial conditions back into the integrator's
# shadow copy. N.B.: ({du, u}_nvec are aliased to {du, u}).
IDAGetConsistentIC(integrator.mem, integrator.u_nvec, integrator.du_nvec)
end
if integrator.t == tstart && integrator.flag < 0
integrator.sol = SciMLBase.solution_new_retcode(integrator.sol,
ReturnCode.InitialFailure)
end
end

DiffEqBase.has_reinit(integrator::AbstractSundialsIntegrator) = true
function DiffEqBase.reinit!(integrator::AbstractSundialsIntegrator,
u0 = integrator.sol.prob.u0;
Expand Down
14 changes: 10 additions & 4 deletions src/common_interface/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, i
stop_at_next_tstop = false,
userdata = nothing,
alias_u0 = false,
initializealg = SundialsDefaultInit(),
kwargs...) where {uType, tupType, isinplace, Method, LinearSolver
}
tType = eltype(tupType)
Expand Down Expand Up @@ -457,7 +458,9 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, i
0,
1,
callback_cache,
0.0)
0.0,
initializealg)
DiffEqBase.initialize_dae!(integrator)
initialize_callbacks!(integrator)
integrator
end # function solve
Expand Down Expand Up @@ -499,6 +502,7 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, i
stop_at_next_tstop = false,
userdata = nothing,
alias_u0 = false,
initializealg = SundialsDefaultInit(),
kwargs...) where {uType, tupType, isinplace, Method,
LinearSolver,
MassLinearSolver}
Expand Down Expand Up @@ -945,8 +949,10 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, i
0,
1,
callback_cache,
0.0)
0.0,
initializealg)

DiffEqBase.initialize_dae!(integrator)
initialize_callbacks!(integrator)
integrator
end # function solve
Expand Down Expand Up @@ -1010,7 +1016,7 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractDAEProblem{uType, duType, tu
advance_to_tstop = false,
stop_at_next_tstop = false,
userdata = nothing,
initializealg = IDADefaultInit(),
initializealg = SundialsDefaultInit(),
kwargs...) where {uType, duType, tupType, isinplace, LinearSolver
}
tType = eltype(tupType)
Expand Down Expand Up @@ -1313,7 +1319,7 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractDAEProblem{uType, duType, tu
dutmp,
initializealg)

DiffEqBase.initialize_dae!(integrator, initializealg)
DiffEqBase.initialize_dae!(integrator)
integrator.u_modified && IDAReinit!(integrator)

if save_start
Expand Down

0 comments on commit de5ced3

Please sign in to comment.