diff --git a/Project.toml b/Project.toml index ca33faf..9beedfe 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,8 @@ authors = ["Chris Rackauckas "] version = "4.26.1" [deps] +Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" +ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" @@ -18,6 +20,8 @@ Sundials_jll = "fb77eaff-e24c-56d4-86b1-d163f2edb164" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" [compat] +Accessors = "0.1.38" +ArrayInterface = "7.17.1" CEnum = "0.5" DataStructures = "0.18" DiffEqBase = "6.154" @@ -36,8 +40,8 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" IncompleteLU = "40713840-3770-5561-ab4c-a76e7d0d7895" ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" ODEProblemLibrary = "fdc4e326-1af4-4b90-96e7-779fcce2daa5" -SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804" SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5" +SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] diff --git a/src/Sundials.jl b/src/Sundials.jl index 90f1a5c..03a6f6c 100644 --- a/src/Sundials.jl +++ b/src/Sundials.jl @@ -5,6 +5,8 @@ module Sundials import Reexport Reexport.@reexport using DiffEqBase using SciMLBase: AbstractSciMLOperator +import Accessors: @reset +import ArrayInterface import SymbolicIndexingInterface as SII import SymbolicIndexingInterface: ParameterIndexingProxy import DataStructures @@ -83,6 +85,7 @@ include("common_interface/verbosity.jl") include("common_interface/algorithms.jl") include("common_interface/integrator_types.jl") include("common_interface/integrator_utils.jl") +include("common_interface/initialize_dae.jl") include("common_interface/solve.jl") import PrecompileTools diff --git a/src/common_interface/initialize_dae.jl b/src/common_interface/initialize_dae.jl new file mode 100644 index 0000000..eb66f17 --- /dev/null +++ b/src/common_interface/initialize_dae.jl @@ -0,0 +1,78 @@ +struct SundialsDefaultInit <: DiffEqBase.DAEInitializationAlgorithm end + +function DiffEqBase.initialize_dae!(integrator::AbstractSundialsIntegrator, initializealg = integrator.initializealg) + _initialize_dae!(integrator, integrator.sol.prob, initializealg, Val(DiffEqBase.isinplace(integrator.sol.prob))) +end + +struct IDADefaultInit <: DiffEqBase.DAEInitializationAlgorithm +end + +function _initialize_dae!(integrator::IDAIntegrator, prob, + initializealg::IDADefaultInit, isinplace) + if integrator.u_modified + IDAReinit!(integrator) + end + integrator.f(integrator.tmp, integrator.du, integrator.u, integrator.p, integrator.t) + tstart, tend = integrator.sol.prob.tspan + if any(abs.(integrator.tmp) .>= integrator.opts.reltol) + if integrator.sol.prob.differential_vars === nothing && !integrator.alg.init_all + error("Must supply differential_vars argument to DAEProblem constructor to use IDA initial value solver.") + end + if integrator.alg.init_all + init_type = IDA_Y_INIT + else + init_type = IDA_YA_YDP_INIT + integrator.flag = IDASetId(integrator.mem, + vec(integrator.sol.prob.differential_vars)) + end + dt = integrator.dt == tstart ? tend : integrator.dt + integrator.flag = IDACalcIC(integrator.mem, init_type, dt) + + # Reflect consistent initial conditions back into the integrator's + # shadow copy. N.B.: ({du, u}_nvec are aliased to {du, u}). + IDAGetConsistentIC(integrator.mem, integrator.u_nvec, integrator.du_nvec) + end + if integrator.t == tstart && integrator.flag < 0 + integrator.sol = SciMLBase.solution_new_retcode(integrator.sol, + ReturnCode.InitialFailure) + end +end + +function _initialize_dae!(integrator, prob, ::SundialsDefaultInit, isinplace) + if SciMLBase.has_initializeprob(prob.f) + _initialize_dae!(integrator, prob, SciMLBase.OverrideInit(), isinplace) + elseif integrator isa IDAIntegrator + _initialize_dae!(integrator, prob, IDADefaultInit(), isinplace) + end +end + +function _initialize_dae!(integrator, prob, initalg::SciMLBase.NoInit, isinplace) end + +function _initialize_dae!(integrator, prob, initalg::SciMLBase.OverrideInit, isinplace::Union{Val{true}, Val{false}}) + nlsolve_alg = KINSOL() + u0, p, success = SciMLBase.get_initial_values(prob, integrator, prob.f, initalg, isinplace; nlsolve_alg, abstol = integrator.opts.abstol, reltol = integrator.opts.reltol) + + if isinplace === Val{true}() + integrator.u .= u0 + if length(integrator.sol.u) == 1 + integrator.sol.u[1] .= u0 + end + else + integrator.u = u0 + if length(integrator.sol.u) == 1 + integrator.sol.u[1] = u0 + end + end + integrator.p = p + sol = integrator.sol + @reset sol.prob.p = integrator.p + integrator.sol = sol + + if !success + integrator.sol = SciMLBase.solution_new_retcode(integrator.sol, ReturnCode.InitialFailure) + end +end + +function _initialize_dae!(integrator, prob, initalg::SciMLBase.CheckInit, isinplace::Union{Val{true}, Val{false}}) + SciMLBase.get_initial_values(prob, integrator, prob.f, initalg, isinplace; abstol = integrator.opts.abstol) +end diff --git a/src/common_interface/integrator_types.jl b/src/common_interface/integrator_types.jl index a1b1605..022fdc8 100644 --- a/src/common_interface/integrator_types.jl +++ b/src/common_interface/integrator_types.jl @@ -40,7 +40,8 @@ mutable struct CVODEIntegrator{N, oType, LStype, Atype, - CallbackCacheType} <: AbstractSundialsIntegrator{algType} + CallbackCacheType, + IA} <: AbstractSundialsIntegrator{algType} u::Array{Float64, N} u_nvec::NVector p::pType @@ -66,6 +67,7 @@ mutable struct CVODEIntegrator{N, vector_event_last_time::Int callback_cache::CallbackCacheType last_event_error::Float64 + initializealg::IA end function (integrator::CVODEIntegrator)(t::Number, @@ -96,7 +98,8 @@ mutable struct ARKODEIntegrator{N, Atype, MLStype, Mtype, - CallbackCacheType} <: AbstractSundialsIntegrator{ARKODE} + CallbackCacheType, + IA} <: AbstractSundialsIntegrator{ARKODE} u::Array{Float64, N} u_nvec::NVector p::pType @@ -124,6 +127,7 @@ mutable struct ARKODEIntegrator{N, vector_event_last_time::Int callback_cache::CallbackCacheType last_event_error::Float64 + initializealg::IA end function (integrator::ARKODEIntegrator)(t::Number, diff --git a/src/common_interface/integrator_utils.jl b/src/common_interface/integrator_utils.jl index a5e273f..30c5626 100644 --- a/src/common_interface/integrator_utils.jl +++ b/src/common_interface/integrator_utils.jl @@ -187,42 +187,6 @@ end # Required for callbacks DiffEqBase.set_proposed_dt!(i::AbstractSundialsIntegrator, dt) = nothing -DiffEqBase.initialize_dae!(integrator::AbstractSundialsIntegrator) = nothing - -struct IDADefaultInit <: DiffEqBase.DAEInitializationAlgorithm -end - -function DiffEqBase.initialize_dae!(integrator::IDAIntegrator, - initializealg::IDADefaultInit) - if integrator.u_modified - IDAReinit!(integrator) - end - integrator.f(integrator.tmp, integrator.du, integrator.u, integrator.p, integrator.t) - tstart, tend = integrator.sol.prob.tspan - if any(abs.(integrator.tmp) .>= integrator.opts.reltol) - if integrator.sol.prob.differential_vars === nothing && !integrator.alg.init_all - error("Must supply differential_vars argument to DAEProblem constructor to use IDA initial value solver.") - end - if integrator.alg.init_all - init_type = IDA_Y_INIT - else - init_type = IDA_YA_YDP_INIT - integrator.flag = IDASetId(integrator.mem, - vec(integrator.sol.prob.differential_vars)) - end - dt = integrator.dt == tstart ? tend : integrator.dt - integrator.flag = IDACalcIC(integrator.mem, init_type, dt) - - # Reflect consistent initial conditions back into the integrator's - # shadow copy. N.B.: ({du, u}_nvec are aliased to {du, u}). - IDAGetConsistentIC(integrator.mem, integrator.u_nvec, integrator.du_nvec) - end - if integrator.t == tstart && integrator.flag < 0 - integrator.sol = SciMLBase.solution_new_retcode(integrator.sol, - ReturnCode.InitialFailure) - end -end - DiffEqBase.has_reinit(integrator::AbstractSundialsIntegrator) = true function DiffEqBase.reinit!(integrator::AbstractSundialsIntegrator, u0 = integrator.sol.prob.u0; diff --git a/src/common_interface/solve.jl b/src/common_interface/solve.jl index a1cb3c0..42b854f 100644 --- a/src/common_interface/solve.jl +++ b/src/common_interface/solve.jl @@ -124,6 +124,7 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, i stop_at_next_tstop = false, userdata = nothing, alias_u0 = false, + initializealg = SundialsDefaultInit(), kwargs...) where {uType, tupType, isinplace, Method, LinearSolver } tType = eltype(tupType) @@ -457,7 +458,9 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, i 0, 1, callback_cache, - 0.0) + 0.0, + initializealg) + DiffEqBase.initialize_dae!(integrator) initialize_callbacks!(integrator) integrator end # function solve @@ -499,6 +502,7 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, i stop_at_next_tstop = false, userdata = nothing, alias_u0 = false, + initializealg = SundialsDefaultInit(), kwargs...) where {uType, tupType, isinplace, Method, LinearSolver, MassLinearSolver} @@ -945,8 +949,10 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, i 0, 1, callback_cache, - 0.0) + 0.0, + initializealg) + DiffEqBase.initialize_dae!(integrator) initialize_callbacks!(integrator) integrator end # function solve @@ -1010,7 +1016,7 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractDAEProblem{uType, duType, tu advance_to_tstop = false, stop_at_next_tstop = false, userdata = nothing, - initializealg = IDADefaultInit(), + initializealg = SundialsDefaultInit(), kwargs...) where {uType, duType, tupType, isinplace, LinearSolver } tType = eltype(tupType) @@ -1313,7 +1319,7 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractDAEProblem{uType, duType, tu dutmp, initializealg) - DiffEqBase.initialize_dae!(integrator, initializealg) + DiffEqBase.initialize_dae!(integrator) integrator.u_modified && IDAReinit!(integrator) if save_start