diff --git a/lib/OrdinaryDiffEqCore/Project.toml b/lib/OrdinaryDiffEqCore/Project.toml index 063b6b7819..f3709df99d 100644 --- a/lib/OrdinaryDiffEqCore/Project.toml +++ b/lib/OrdinaryDiffEqCore/Project.toml @@ -49,7 +49,7 @@ Accessors = "0.1.36" Adapt = "3.0, 4" ArrayInterface = "7" DataStructures = "0.18" -DiffEqBase = "6.159" +DiffEqBase = "6.160" DiffEqDevTools = "2.44.4" DocStringExtensions = "0.9" EnumX = "1" @@ -70,7 +70,7 @@ Random = "<0.0.1, 1" RecursiveArrayTools = "2.36, 3" Reexport = "1.0" SafeTestsets = "0.1.0" -SciMLBase = "2.59.2" +SciMLBase = "2.60" SciMLOperators = "0.3" SciMLStructures = "1" SimpleUnPack = "1" diff --git a/lib/OrdinaryDiffEqCore/src/alg_utils.jl b/lib/OrdinaryDiffEqCore/src/alg_utils.jl index 5d144fcfae..d5f89349d6 100644 --- a/lib/OrdinaryDiffEqCore/src/alg_utils.jl +++ b/lib/OrdinaryDiffEqCore/src/alg_utils.jl @@ -17,6 +17,9 @@ end SciMLBase.forwarddiffs_model_time(alg::RosenbrockAlgorithm) = true +SciMLBase.allows_late_binding_tstops(::OrdinaryDiffEqAlgorithm) = true +SciMLBase.allows_late_binding_tstops(::DAEAlgorithm) = true + # isadaptive is defined below. ## OrdinaryDiffEq Internal Traits diff --git a/lib/OrdinaryDiffEqCore/src/solve.jl b/lib/OrdinaryDiffEqCore/src/solve.jl index f2c9e2e087..42fd07a5bf 100644 --- a/lib/OrdinaryDiffEqCore/src/solve.jl +++ b/lib/OrdinaryDiffEqCore/src/solve.jl @@ -242,6 +242,12 @@ function DiffEqBase.__init( resType = typeof(res_prototype) end + if tstops isa AbstractArray || tstops isa Tuple || tstops isa Number + _tstops = nothing + else + _tstops = tstops + tstops = () + end tstops_internal = initialize_tstops(tType, tstops, d_discontinuities, tspan) saveat_internal = initialize_saveat(tType, saveat, tspan) d_discontinuities_internal = initialize_d_discontinuities(tType, d_discontinuities, @@ -542,6 +548,13 @@ function DiffEqBase.__init( end end + if _tstops !== nothing + tstops = _tstops(parameter_values(integrator), prob.tspan) + for tstop in tstops + add_tstop!(integrator, tstop) + end + end + handle_dt!(integrator) integrator end diff --git a/test/interface/ode_tstops_tests.jl b/test/interface/ode_tstops_tests.jl index ce85f0e859..a911ecfe8a 100644 --- a/test/interface/ode_tstops_tests.jl +++ b/test/interface/ode_tstops_tests.jl @@ -76,3 +76,15 @@ end prob = ODEProblem(ff, [0.0], (0.0f0, 1.0f0)) sol = solve(prob, Tsit5(), tstops = [tval], callback = cb) end + +@testset "Late binding tstops" begin + function rhs(u, p, t) + u * p + t + end + prob = ODEProblem(rhs, 1.0, (0.0, 1.0), 0.1; tstops = (p, tspan) -> tspan[1]:p:tspan[2]) + sol = solve(prob, Tsit5()) + @test 0.0:0.1:1.0 ⊆ sol.t + prob2 = remake(prob; p = 0.07) + sol2 = solve(prob2, Tsit5()) + @test 0.0:0.07:1.0 ⊆ sol2.t +end