remake ignores u0 updates when the initialization is trivial #3318

Open
SebastianM-C opened this issue Jan 14, 2025 · 2 comments
Labels
bug Something isn't working

SebastianM-C commented Jan 14, 2025

Describe the bug 🐞

Edit: The issue is more general than I initially described. The fact that we have dual numbers is not relevant; see the next post for a simpler MWE.

Passing a Vector{ForwardDiff.Dual} to remake leads to an internal type promotion by updated_u0_p, but by the end of the function the values are back to floats. Furthermore, if this is an ODEProblem, init also promotes u0 internally, but it then uses remake to update the values in the problem, so the duals are dropped again, leading to method errors in the problem function.

Expected behavior

remake should not drop the duals.

Minimal Reproducible Example 👇

using ModelingToolkit
using ModelingToolkit: D_nounits, t_nounits as t
using DifferentiationInterface
using OrdinaryDiffEqDefault
using SciMLBase
using SymbolicIndexingInterface

ps = @parameters k1 = 1.0 c1 = missing [guess = 2] c1_cond1 = 2.0
sts = @variables s1(t) = 2.0 s1s2(t) = 2.0 s2(t) = 2.0
eqs = [D_nounits(s1) ~ -0.25 * c1 * k1 * s1 * s2
    D_nounits(s1s2) ~ 0.25 * c1 * k1 * s1 * s2
    D_nounits(s2) ~ -0.25 * c1 * k1 * s1 * s2]

model = structural_simplify(ODESystem(eqs,
    t,
    sts,
    ps;
    name=:reactionsystem))

prob0 = ODEProblem{true, SciMLBase.FullSpecialize}(model, [], (0., 1), [c1 => c1_cond1])

setk = setsym_oop(prob0, [c1_cond1])

function foo2(x, prob0)
    (u0, p) = setk(prob0, (x))
    prob1 = remake(prob0; u0, p)
    # sol = solve(prob1)
    integ = init(prob1, DefaultODEAlgorithm())
    @info integ.u
    sum(integ.u)
end

foo2([3], prob0)

gradient(x->foo2(x, prob0), AutoForwardDiff(), [3.])

Error & Stacktrace ⚠️

ERROR: MethodError: no method matching Float64(::ForwardDiff.Dual{ForwardDiff.Tag{var"#19#20", Float64}, Float64, 1})
The type `Float64` exists, but no method is defined for this combination of argument types when trying to construct it.

Closest candidates are:
  (::Type{T})(::Real, ::RoundingMode) where T<:AbstractFloat
   @ Base rounding.jl:265
  (::Type{T})(::T) where T<:Number
   @ Core boot.jl:900
  Float64(::IrrationalConstants.Logπ)
   @ IrrationalConstants ~/.julia/packages/IrrationalConstants/vp5v4/src/macro.jl:112
  ...

Stacktrace:
  [1] convert(::Type{Float64}, x::ForwardDiff.Dual{ForwardDiff.Tag{var"#19#20", Float64}, Float64, 1})
    @ Base ./number.jl:7
  [2] setindex!(A::Vector{Float64}, x::ForwardDiff.Dual{ForwardDiff.Tag{var"#19#20", Float64}, Float64, 1}, i::Int64)
    @ Base ./array.jl:987
  [3] macro expansion
    @ ~/.julia/packages/SymbolicUtils/bpwpv/src/code.jl:433 [inlined]
  [4] macro expansion
    @ ~/.julia/packages/Symbolics/PxO3a/src/build_function.jl:558 [inlined]
  [5] macro expansion
    @ ~/.julia/packages/SymbolicUtils/bpwpv/src/code.jl:390 [inlined]
  [6] macro expansion
    @ ~/.julia/packages/Symbolics/PxO3a/src/build_function.jl:342 [inlined]
  [7] macro expansion
    @ ~/.julia/packages/RuntimeGeneratedFunctions/M9ZX8/src/RuntimeGeneratedFunctions.jl:163 [inlined]
  [8] macro expansion
    @ ./none:0 [inlined]
  [9] generated_callfunc
    @ ./none:0 [inlined]
 [10] (::RuntimeGeneratedFunctions.RuntimeGeneratedFunction{…})(::Vector{…}, ::Vector{…}, ::Vector{…}, ::Float64)
    @ RuntimeGeneratedFunctions ~/.julia/packages/RuntimeGeneratedFunctions/M9ZX8/src/RuntimeGeneratedFunctions.jl:150
 [11] (::ModelingToolkit.var"#f#1060"{})(du::Vector{…}, u::Vector{…}, p::MTKParameters{…}, t::Float64)
    @ ModelingToolkit ~/.julia/dev/ModelingToolkit/src/systems/diffeqs/abstractodesystem.jl:379
 [12] (::ODEFunction{…})(::Vector{…}, ::Vararg{…})
    @ SciMLBase ~/.julia/dev/SciMLBase/src/scimlfunctions.jl:2468
 [13] initialize!(integrator::OrdinaryDiffEqCore.ODEIntegrator{…}, cache::OrdinaryDiffEqTsit5.Tsit5Cache{…})
    @ OrdinaryDiffEqTsit5 ~/.julia/packages/OrdinaryDiffEqTsit5/DHYtz/src/tsit_perform_step.jl:175
 [14] initialize!(integrator::OrdinaryDiffEqCore.ODEIntegrator{…}, cache::OrdinaryDiffEqCore.DefaultCache{…})
    @ OrdinaryDiffEqCore ~/.julia/packages/OrdinaryDiffEqCore/3Talm/src/perform_step/composite_perform_step.jl:38
 [15] __init(prob::ODEProblem{…}, alg::CompositeAlgorithm{…}, timeseries_init::Tuple{}, ts_init::Tuple{}, ks_init::Tuple{}, recompile::Type{…}; saveat::Tuple{}, tstops::Tuple{}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_on::Bool, save_start::Bool, save_end::Nothing, callback::Nothing, dense::Bool, calck::Bool, dt::Float64, dtmin::Float64, dtmax::Float64, force_dtmin::Bool, adaptive::Bool, gamma::Rational{…}, abstol::Nothing, reltol::Nothing, qmin::Rational{…}, qmax::Int64, qsteady_min::Int64, qsteady_max::Int64, beta1::Nothing, beta2::Nothing, qoldinit::Rational{…}, controller::Nothing, fullnormalize::Bool, failfactor::Int64, maxiters::Int64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), internalopnorm::typeof(opnorm), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), progress_id::Symbol, userdata::Nothing, allow_extrapolation::Bool, initialize_integrator::Bool, alias::ODEAliasSpecifier, initializealg::OrdinaryDiffEqCore.DefaultInit, kwargs::@Kwargs{})
    @ OrdinaryDiffEqCore ~/.julia/packages/OrdinaryDiffEqCore/3Talm/src/solve.jl:582
 [16] __init (repeats 5 times)
    @ ~/.julia/packages/OrdinaryDiffEqCore/3Talm/src/solve.jl:11 [inlined]
 [17] #init_call#40
    @ ~/.julia/packages/DiffEqBase/R2Vjs/src/solve.jl:545 [inlined]
 [18] init_call(_prob::ODEProblem{…}, args::CompositeAlgorithm{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/R2Vjs/src/solve.jl:518
 [19] init_up(prob::ODEProblem{…}, sensealg::Nothing, u0::Vector{…}, p::MTKParameters{…}, args::CompositeAlgorithm{…}; kwargs::@Kwargs{})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/R2Vjs/src/solve.jl:586
 [20] init_up
    @ ~/.julia/packages/DiffEqBase/R2Vjs/src/solve.jl:566 [inlined]
 [21] #init#41
    @ ~/.julia/packages/DiffEqBase/R2Vjs/src/solve.jl:559 [inlined]
 [22] init(prob::ODEProblem{…}, args::CompositeAlgorithm{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/R2Vjs/src/solve.jl:549
 [23] foo2(x::Vector{…}, prob0::ODEProblem{…})
    @ Main ~/juliasim/dev/remake_dual_drop.jl:28
 [24] (::var"#19#20")(x::Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#19#20", Float64}, Float64, 1}})
    @ Main ~/juliasim/dev/remake_dual_drop.jl:35
 [25] vector_mode_dual_eval!
    @ ~/.julia/packages/ForwardDiff/UBbGT/src/apiutils.jl:24 [inlined]
 [26] vector_mode_gradient(f::var"#19#20", x::Vector{…}, cfg::ForwardDiff.GradientConfig{…})
    @ ForwardDiff ~/.julia/packages/ForwardDiff/UBbGT/src/gradient.jl:91
 [27] gradient
    @ ~/.julia/packages/ForwardDiff/UBbGT/src/gradient.jl:20 [inlined]
 [28] gradient(f::var"#19#20", x::Vector{Float64}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{…}, Float64, 1, Vector{…}})
    @ ForwardDiff ~/.julia/packages/ForwardDiff/UBbGT/src/gradient.jl:17
 [29] gradient(f::var"#19#20", x::Vector{Float64})
    @ ForwardDiff ~/.julia/packages/ForwardDiff/UBbGT/src/gradient.jl:17
 [30] gradient(::var"#19#20", ::AutoForwardDiff{nothing, Nothing}, ::Vector{Float64})
    @ DifferentiationInterfaceForwardDiffExt ~/.julia/packages/DifferentiationInterface/6QHLL/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl:295
 [31] top-level scope
    @ ~/dev/remake_dual_drop.jl:35
Some type information was truncated. Use `show(err)` to see complete types.
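
Frames [1] and [2] of the stacktrace above show the generated problem function writing a Dual into a Vector{Float64}. A minimal sketch of that failure in isolation, assuming only that ForwardDiff is installed (the names d, du and du_generic are illustrative and not taken from ModelingToolkit):

```julia
using ForwardDiff

# A dual number with value 3.0 and a single partial equal to 1.0.
d = ForwardDiff.Dual(3.0, 1.0)

# A buffer with a concrete Float64 element type cannot store a Dual:
# setindex! calls convert(Float64, d), which has no method.
du = zeros(Float64, 1)
err = try
    du[1] = d
    nothing
catch e
    e
end
@show err isa MethodError

# A buffer whose element type follows the input can store it.
du_generic = zeros(typeof(d), 1)
du_generic[1] = 2 * d
@show ForwardDiff.value(du_generic[1])
@show ForwardDiff.partials(du_generic[1], 1)
```

This only illustrates why a Float64 state vector reaching the integrator produces the MethodError; it does not reproduce the remake behaviour itself.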

Environment (please complete the following information):

  • Output of using Pkg; Pkg.status()
  [a0c0ee7d] DifferentiationInterface v0.6.28
  [961ee093] ModelingToolkit v9.60.0
  [50262376] OrdinaryDiffEqDefault v1.2.0
  [0bca4576] SciMLBase v2.70.0
  [2efcf032] SymbolicIndexingInterface v0.3.37
  • Output of using Pkg; Pkg.status(; mode = PKGMODE_MANIFEST)
  • Output of versioninfo()
Julia Version 1.11.2
Commit 5e9a32e7af2 (2024-12-01 20:02 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 32 × Intel(R) Core(TM) i9-14900K
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, alderlake)
Threads: 32 default, 0 interactive, 16 GC (on 32 virtual cores)
Environment:
  JULIA_EDITOR = code

Additional context

I can reproduce this with SCCNonlinearProblem too, but it needs SciML/SciMLBase.jl#904

using ModelingToolkit
using NonlinearSolve, SCCNonlinearSolve
using OrdinaryDiffEq
using SciMLBase, Symbolics
using LinearAlgebra, Test
using ModelingToolkit: t_nounits as t, D_nounits as D

function f!(du, u, (p1, p2), t)
    x = (*)(p1[4], u[1])
    y = (*)(p1[4], (+)(0.1016, (*)(-1, u[1])))
    z1 = ifelse((<)(p2[1], 0),
        (*)((*)(457896.07999999996, p1[2]), sqrt((*)(1.1686468413521012e-5, p1[3]))),
        0)
    z2 = ifelse((>)(p2[1], 0),
        (*)((*)((*)(0.58, p1[2]), sqrt((*)(1 // 86100, p1[3]))), u[4]),
        0)
    z3 = ifelse((>)(p2[1], 0),
        (*)((*)(457896.07999999996, p1[2]), sqrt((*)(1.1686468413521012e-5, p1[3]))),
        0)
    z4 = ifelse((<)(p2[1], 0),
        (*)((*)((*)(0.58, p1[2]), sqrt((*)(1 // 86100, p1[3]))), u[5]),
        0)
    du[1] = p2[1]
    du[2] = (+)(z1, (*)(-1, z2))
    du[3] = (+)(z3, (*)(-1, z4))
    du[4] = (+)((*)(-1, u[2]), (*)((*)(1 // 86100, y), u[4]))
    du[5] = (+)((*)(-1, u[3]), (*)((*)(1 // 86100, x), u[5]))
end
p = (
    [0.04864391799335977, 7.853981633974484e-5, 1.4034843205574914,
        0.018241469247509915, 300237.05, 9.226186337232914],
    [0.0508])
u0 = [0.0, 0.0, 0.0, 789476.0, 101325.0]
tspan = (0.0, 1.0)
mass_matrix = [1.0 0.0 0.0 0.0 0.0; 0.0 1.0 0.0 0.0 0.0; 0.0 0.0 1.0 0.0 0.0;
    0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0]
dt = 1e-3
function nlf(u1, (u0, p))
    resid = Any[0 for _ in u0]
    f!(resid, u1, p, 0.0)
    return mass_matrix * (u1 - u0) - dt * resid
end

prob = NonlinearProblem(nlf, u0, (u0, p))
@test_throws Exception solve(prob, SimpleNewtonRaphson(), abstol=1e-9)
sol = solve(prob, TrustRegion(); abstol=1e-9)

@variables u[1:5] [irreducible = true]
@parameters p1[1:6] p2
eqs = 0 .~ collect(nlf(u, (u0, (p1, p2))))
@mtkbuild sys = NonlinearSystem(eqs, [u], [p1, p2])
sccprob = SCCNonlinearProblem(sys, [u => u0], [p1 => p[1], p2 => p[2][]])
sccsol = solve(sccprob, SimpleNewtonRaphson(); abstol=1e-9)
@test SciMLBase.successful_retcode(sccsol)
@test norm(sccsol.resid) < norm(sol.resid)

# the above is taken from the scc testsuite

using SymbolicIndexingInterface

setter = setsym_oop(sccprob, [p1[1]])

function loss(x, (prob, setter))
    (u0, p) = setter(prob, x)
    newprob = remake(prob; u0, p)
    sum(solve(newprob, SimpleNewtonRaphson(); abstol=1e-9))
end

x0 = [sccsol.ps[p2]]
loss(x0, (sccprob, setter))

using DifferentiationInterface
using ForwardDiff

gradient(x->loss(x, (sccprob, setter)), AutoForwardDiff(), x0)
@SebastianM-C added the bug label on Jan 14, 2025

SebastianM-C (Contributor, Author) commented:

I have an even simpler MWE:

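using ModelingToolkit
using ModelingToolkit: t_nounits as t, D_nounits as D
using OrdinaryDiffEqDefault
using SciMLBase
using SymbolicIndexingInterface

ps = @parameters k = missing [guess = 2]
sts = @variables x(t) = 1.5

eqs = [
    D(x) ~ -k * x
]
model = structural_simplify(ODESystem(eqs, t, sts, ps; name=:model, initialization_eqs=[k ~ 2]))

prob0 = ODEProblem(model, [], (0., 1))
setk = setsym_oop(prob0, [x])

# Request x = 1.0 instead of the default 1.5.
(u0, p) = setk(prob0, [1.])
prob1 = remake(prob0; u0, p)
integ = init(prob1, DefaultODEAlgorithm())

integ.u == [1]  # expected to hold, but the u0 update is ignored

The fact that we had duals was not relevant; any update to u0 is simply ignored.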

@SebastianM-C changed the title from "remake with dual numbers resets values to floats when the initialization is trivial" to "remake ignores u0 updates when the initialization is trivial" on Jan 14, 2025

AayushSabharwal (Member) commented:

In the simple MWE, this happens because x has a default of 1.5, so x ~ 1.5 is an equation in the initialization system. Unfortunately, even if you remove the default, the system still has an initialization equation, so it will always build an initializeprob, and whatever initial value you provide for x becomes an equation in it. The "solution" is to make the initial value of x a parameter of the model:

@parameters x0
model = structural_simplify(ODESystem(eqs, t, sts, [ps..., x0]; name=:model, defaults = [x => x0], initialization_eqs=[k ~ 2]))
prob0 = ODEProblem(model, [], (0., 1), [x0 => 1.5])

setk = setsym_oop(prob0, [x0]) # note: x0 not x

(u0, p) = setk(prob0, [1.])
prob1 = remake(prob0; u0, p)

integ = init(prob1, DefaultODEAlgorithm())
integ.u[1] == 1
