diff --git a/Project.toml b/Project.toml index bd6935f76f..e4cd9c80a9 100644 --- a/Project.toml +++ b/Project.toml @@ -142,6 +142,7 @@ SparseDiffTools = "2" Static = "0.8, 1" StaticArrayInterface = "1.2" StaticArrays = "1.0" +SymbolicIndexingInterface = "0.3.31" TruncatedStacktraces = "1.2" julia = "1.10" @@ -166,9 +167,10 @@ ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [targets] -test = ["Calculus", "ComponentArrays", "Symbolics", "AlgebraicMultigrid", "IncompleteLU", "DiffEqCallbacks", "DiffEqDevTools", "ODEProblemLibrary", "ElasticArrays", "InteractiveUtils", "ParameterizedFunctions", "PoissonRandom", "Printf", "Random", "ReverseDiff", "SafeTestsets", "SparseArrays", "Statistics", "Test", "Unitful", "ModelingToolkit", "Pkg", "NLsolve"] +test = ["Calculus", "ComponentArrays", "Symbolics", "AlgebraicMultigrid", "IncompleteLU", "DiffEqCallbacks", "DiffEqDevTools", "ODEProblemLibrary", "ElasticArrays", "InteractiveUtils", "ParameterizedFunctions", "PoissonRandom", "Printf", "Random", "ReverseDiff", "SafeTestsets", "SparseArrays", "Statistics", "Test", "Unitful", "ModelingToolkit", "Pkg", "NLsolve", "SymbolicIndexingInterface"] diff --git a/lib/OrdinaryDiffEqCore/src/initialize_dae.jl b/lib/OrdinaryDiffEqCore/src/initialize_dae.jl index d769122a84..7e76736f7f 100644 --- a/lib/OrdinaryDiffEqCore/src/initialize_dae.jl +++ b/lib/OrdinaryDiffEqCore/src/initialize_dae.jl @@ -156,12 +156,23 @@ function _initialize_dae!(integrator, prob::AbstractDEProblem, nlsolve_alg = default_nlsolve(alg.nlsolve, isinplace, initializeprob.u0, initializeprob, isAD) - u0, p, success = SciMLBase.get_initial_values(prob, integrator, prob.f, alg, isinplace; nlsolve_alg, abstol = integrator.opts.abstol, reltol = integrator.opts.reltol) + if prob isa DAEProblem + du0, u0, p, success = SciMLBase.get_initial_values(prob, integrator, prob.f, alg, isinplace; nlsolve_alg, abstol = integrator.opts.abstol, reltol = integrator.opts.reltol, return_du0 = true) + else + u0, p, success = SciMLBase.get_initial_values(prob, integrator, prob.f, alg, isinplace; nlsolve_alg, abstol = integrator.opts.abstol, reltol = integrator.opts.reltol) + du0 = nothing + end if isinplace === Val{true}() integrator.u .= u0 + if du0 !== nothing + integrator.du .= du0 + end elseif isinplace === Val{false}() integrator.u = u0 + if du0 !== nothing + integrator.du = du0 + end else error("Unreachable reached. Report this error.") end diff --git a/test/interface/dae_initialization_tests.jl b/test/interface/dae_initialization_tests.jl index e00d8270e5..db3b8c6579 100644 --- a/test/interface/dae_initialization_tests.jl +++ b/test/interface/dae_initialization_tests.jl @@ -114,3 +114,38 @@ prob = ODEProblem(f, ones(3), (0.0, 1.0)) integrator = init(prob, Rodas5P(), initializealg = ShampineCollocationInit(1.0, BrokenNLSolve())) @test all(isequal(reinterpret(Float64, 0xDEADBEEFDEADBEEF)), integrator.u) + +@testset "OverrideInit for DAEProblem" begin + function daerhs(du, u, p, t) + return [u[1] * t + p, u[1]^2 - u[2]^2] + end + # unknowns are u[2], p, D(u[1]), D(u[2]). Parameters are u[1], t + initprob = NonlinearProblem([1.0, 1.0, 1.0, 1.0], [1.0, 0.0]) do x, _p + u2, p, du1, du2 = x + u1, t = _p + return [u1^3 - u2^3, p^2 - 2p + 1, du1 - u1 * t - p, 2u1 * du1 - 2u2 * du2] + end + + update_initializeprob! = function (iprob, integ) + iprob.p[1] = integ.u[1] + iprob.p[2] = integ.t + end + initprobmap = function (nlsol) + return [parameter_values(nlsol)[1], nlsol.u[1]] + end + initprobpmap = function (_, nlsol) + return nlsol.u[2] + end + initprob_du0map = function (nlsol) + return nlsol.u[3:4] + end + initialization_data = SciMLBase.OverrideInitData( + initprob, update_initializeprob!, initprobmap, initprobpmap, initprob_du0map) + fn = DAEFunction(daerhs; initialization_data) + prob = DAEProblem(fn, [0.0, 0.0], [2.0, 0.0], (0.0, 1.0), 0.0) + integ = init(prob, DImplicitEuler()) + @test integ.du ≈ [1.0, 1.0] + @test integ.u ≈ [2.0, 2.0] + @test integ.p ≈ 1.0 + @test integ.sol.retcode != SciMLBase.ReturnCode.InitialFailure +end