Skip to content

Commit

Permalink
Merge pull request #2241 from SciML/defaults
Browse files Browse the repository at this point in the history
Fix case of nothing initiailizeprob with non-stiff default alg
  • Loading branch information
ChrisRackauckas authored Jun 8, 2024
2 parents 8db5efe + d86c55b commit 06a9277
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions src/initialize_dae.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ end
BrownFullBasicInit(abstol) = BrownFullBasicInit(; abstol = abstol, nlsolve = nothing)

default_nlsolve(alg, isinplace, u, initprob, autodiff = false) = alg

function default_nlsolve(::Nothing, isinplace, u::Nothing, ::NonlinearProblem, autodiff = false)
nothing
end
function default_nlsolve(::Nothing, isinplace, u, ::NonlinearProblem, autodiff = false)
FastShortcutNonlinearPolyalg(;
autodiff = autodiff ? AutoForwardDiff() : AutoFiniteDiff())
Expand All @@ -30,6 +34,10 @@ function default_nlsolve(::Nothing, isinplace::Val{false}, u::StaticArray,
SimpleTrustRegion(autodiff = autodiff ? AutoForwardDiff() : AutoFiniteDiff())
end

function default_nlsolve(
::Nothing, isinplace, u::Nothing, ::NonlinearLeastSquaresProblem, autodiff = false)
nothing
end
function default_nlsolve(
::Nothing, isinplace, u, ::NonlinearLeastSquaresProblem, autodiff = false)
FastShortcutNLLSPolyalg(; autodiff = autodiff ? AutoForwardDiff() : AutoFiniteDiff())
Expand Down Expand Up @@ -138,8 +146,13 @@ function _initialize_dae!(integrator, prob::Union{ODEProblem, DAEProblem},
# If it doesn't have autodiff, assume it comes from symbolic system like ModelingToolkit
# Since then it's the case of not a DAE but has initializeprob
# In which case, it should be differentiable
isAD = has_autodiff(integrator.alg) ? alg_autodiff(integrator.alg) isa AutoForwardDiff :
true
isAD = if initializeprob.u0 === nothing
AutoForwardDiff
elseif has_autodiff(integrator.alg)
alg_autodiff(integrator.alg) isa AutoForwardDiff
else
true
end

alg = default_nlsolve(alg.nlsolve, isinplace, initializeprob.u0, initializeprob, isAD)
nlsol = solve(initializeprob, alg)
Expand Down

0 comments on commit 06a9277

Please sign in to comment.