Skip to content

Commit

Permalink
Merge pull request #7 from JuliaDiffEq/userchain
Browse files Browse the repository at this point in the history
use Chain from user
  • Loading branch information
ChrisRackauckas authored Jun 19, 2019
2 parents 710039f + ecaaa39 commit b1c49a1
Show file tree
Hide file tree
Showing 7 changed files with 85 additions and 88 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ os:
- linux
- osx
julia:
- 1.0
- 1.1
- nightly
matrix:
allow_failures:
Expand Down
8 changes: 0 additions & 8 deletions REQUIRE

This file was deleted.

8 changes: 4 additions & 4 deletions src/NeuralNetDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@ using Reexport
using Flux

abstract type NeuralNetDiffEqAlgorithm <: DiffEqBase.AbstractODEAlgorithm end
struct nnode <: NeuralNetDiffEqAlgorithm
hl_width::Int
struct nnode{C,O} <: NeuralNetDiffEqAlgorithm
chain::C
opt::O
end
nnode(;hl_width=10) = nnode(hl_width)
nnode(chain;opt=Flux.ADAM(0.1)) = nnode(chain,opt)
export nnode

include("solve.jl")
include("training_utils.jl")

end # module
74 changes: 34 additions & 40 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,66 +2,60 @@ function DiffEqBase.solve(
prob::DiffEqBase.AbstractODEProblem,
alg::NeuralNetDiffEqAlgorithm,
args...;
dt = error("dt must be set."),
dt,
timeseries_errors = true,
save_everystep=true,
adaptive=false,
abstol = 1f-6,
verbose = false,
maxiters = 100)

DiffEqBase.isinplace(prob) && error("Only out-of-place methods are allowed!")

u0 = prob.u0
tspan = prob.tspan
f = prob.f
p = prob.p
t0 = tspan[1]

#types and dimensions
# uElType = eltype(u0)
# tType = typeof(tspan[1])
# outdim = length(u0)

#hidden layer
hl_width = alg.hl_width

#initialization of weights and bias
P = init_params(hl_width)

#The phi trial solution
phi(P,x) = u0 .+ x.*predict(P,x)
chain = alg.chain
opt = alg.opt
ps = Flux.params(chain)
data = Iterators.repeated((), maxiters)

#train points generation
x = generate_data(tspan[1],tspan[2],dt)
y = [f(phi(P, i)[1].data, p, i) for i in x]
px =Flux.param(x)
data = [(px, y)]
ts = tspan[1]:dt:tspan[2]

#initialization of optimization parameters (ADAM by default for now)
η = 0.1
β1 = 0.9
β2 = 0.95
opt = Flux.ADAM(η, (β1, β2))

ps = Flux.Params(P)

#derivatives of a function f
dfdx(i) = Tracker.gradient(() -> sum(phi(P,i)), Flux.params(i); nest = true)
#loss function for training
loss(x, y) = sum(abs2, [dfdx(i)[i] for i in x] .- y)
#The phi trial solution
phi(t) = u0 .+ (t .- tspan[1]).*chain(Tracker.collect([t]))

if u0 isa Number
dfdx = t -> Tracker.gradient(t -> sum(phi(t)), t; nest = true)[1]
loss = () -> sum(abs2,sum(abs2,dfdx(t) .- f(phi(t)[1],p,t)[1]) for t in ts)
else
dfdx = t -> (phi(t+sqrt(eps(typeof(dt)))) - phi(t)) / sqrt(eps(typeof(dt)))
#dfdx(t) = Flux.Tracker.forwarddiff(phi,t)
#dfdx(t) = Tracker.collect([Flux.Tracker.gradient(t->phi(t)[i],t, nest=true) for i in 1:length(u0)])
#loss function for training
loss = () -> sum(abs2,sum(abs2,dfdx(t) - f(phi(t),p,t)) for t in ts)
end

@time for iters=1:maxiters
Flux.train!(loss, ps, data, opt)
if mod(iters,50) == 0
loss_value = loss(px,y).data
println((:iteration,iters,:loss,loss_value))
if loss_value < 10^(-6.0)
break
end
end
cb = function ()
l = loss()
verbose && println("Current loss is: $l")
l < abstol && Flux.stop()
end
Flux.train!(loss, ps, data, opt; cb = cb)

#solutions at timepoints
u = [phi(P,i)[1].data for i in x]
if u0 isa Number
u = [phi(t)[1].data for t in ts]
else
u = [phi(t).data for t in ts]
end

sol = DiffEqBase.build_solution(prob,alg,x,u,calculate_error = false)
sol = DiffEqBase.build_solution(prob,alg,ts,u,calculate_error = false)
DiffEqBase.has_analytic(prob.f) && DiffEqBase.calculate_solution_errors!(sol;timeseries_errors=true,dense_errors=false)
sol
end #solve
16 changes: 0 additions & 16 deletions src/training_utils.jl

This file was deleted.

2 changes: 0 additions & 2 deletions test/REQUIRE

This file was deleted.

63 changes: 46 additions & 17 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,29 +1,58 @@
using NeuralNetDiffEq, Test
using Test, Flux, NeuralNetDiffEq
using DiffEqDevTools

# Run a solve
# Run a solve on scalars
linear = (u,p,t) -> cos(2pi*t)
tspan = (0.0,1.0)
u0 = 0.0
tspan = (0.0f0, 1.0f0)
u0 = 0.0f0
prob = ODEProblem(linear, u0 ,tspan)
sol = solve(prob, NeuralNetDiffEq.nnode(5), dt=1/20, maxiters=300)
# println(sol)
#plot(sol)
#plot!(sol.t, t -> sin(2pi*t) / (2*pi), lw=3,ls=:dash,label="True Solution!")
chain = Flux.Chain(Dense(1,5,σ),Dense(5,1))
opt = Flux.ADAM(0.1, (0.9, 0.95))
sol = solve(prob, NeuralNetDiffEq.nnode(chain,opt), dt=1/20f0, verbose = true,
abstol=1e-10, maxiters = 200)

# Run a solve on vectors
linear = (u,p,t) -> [cos(2pi*t)]
tspan = (0.0f0, 1.0f0)
u0 = [0.0f0]
prob = ODEProblem(linear, u0 ,tspan)
chain = Flux.Chain(Dense(1,5,σ),Dense(5,1))
opt = Flux.ADAM(0.1, (0.9, 0.95))
sol = solve(prob, NeuralNetDiffEq.nnode(chain,opt), dt=1/20f0, abstol=1e-10,
verbose = true, maxiters=200)

#Example 1
linear = (u,p,t) -> t^3 + 2*t + (t^2)*((1+3*(t^2))/(1+t+(t^3))) - u*(t + ((1+3*(t^2))/(1+t+t^3)))
linear_analytic = (u0,p,t) -> exp(-(t^2)/2)/(1+t+t^3) + t^2
prob = ODEProblem(ODEFunction(linear,analytic=linear_analytic),1/2,(0.0,1.0))
dts = 1 ./ 2 .^ (10:-1:7)
sim = test_convergence(dts, prob, nnode())
@test abs(sim.𝒪est[:l2]) < 0.3
linear = (u,p,t) -> @. t^3 + 2*t + (t^2)*((1+3*(t^2))/(1+t+(t^3))) - u*(t + ((1+3*(t^2))/(1+t+t^3)))
linear_analytic = (u0,p,t) -> [exp(-(t^2)/2)/(1+t+t^3) + t^2]
prob = ODEProblem(ODEFunction(linear,analytic=linear_analytic),[1f0],(0.0f0,1.0f0))
chain = Flux.Chain(Dense(1,5,σ),Dense(5,1))
opt = Flux.ADAM(0.1, (0.9, 0.95))
sol = solve(prob,NeuralNetDiffEq.nnode(chain,opt),verbose = true, dt=1/5f0)
err = sol.errors[:l2]
sol = solve(prob,NeuralNetDiffEq.nnode(chain,opt),verbose = true, dt=1/20f0)
sol.errors[:l2]/err < 0.5

#=
dts = 1f0 ./ 2f0 .^ (6:-1:2)
sim = test_convergence(dts, prob, NeuralNetDiffEq.nnode(chain, opt))
@test abs(sim.𝒪est[:l2]) < 0.1
@test minimum(sim.errors[:l2]) < 0.5
=#

#Example 2
linear = (u,p,t) -> -u/5 + exp(-t/5).*cos(t)
linear_analytic = (u0,p,t) -> exp(-t/5)*(u0 + sin(t))
prob = ODEProblem(ODEFunction(linear,analytic=linear_analytic),0.0,(0.0,1.0))
sim = test_convergence(dts, prob, nnode())
prob = ODEProblem(ODEFunction(linear,analytic=linear_analytic),0.0f0,(0.0f0,1.0f0))
chain = Flux.Chain(Dense(1,5,σ),Dense(5,1))
opt = Flux.ADAM(0.1, (0.9, 0.95))
sol = solve(prob,NeuralNetDiffEq.nnode(chain,opt),verbose = true, dt=1/5f0)
err = sol.errors[:l2]
sol = solve(prob,NeuralNetDiffEq.nnode(chain,opt),verbose = true, dt=1/20f0)
sol.errors[:l2]/err < 0.5

#=
dts = 1f0 ./ 2f0 .^ (6:-1:2)
sim = test_convergence(dts, prob, NeuralNetDiffEq.nnode(chain, opt))
@test abs(sim.𝒪est[:l2]) < 0.5
@test minimum(sim.errors[:l2]) < 0.3
@test minimum(sim.errors[:l2]) < 0.1
=#

0 comments on commit b1c49a1

Please sign in to comment.