trpo.py
import numpy as np
import torch
from torch.autograd import Variable

from utils import get_flat_params_from, set_flat_params_to

def conjugate_gradients(Avp, b, nsteps, residual_tol=1e-10):
    """Solve A x = b with the conjugate gradient method, given only the
    matrix-vector product Avp(p) = A @ p."""
    x = torch.zeros_like(b)
    r = b.clone()
    p = b.clone()
    rdotr = torch.dot(r, r)
    for _ in range(nsteps):
        _Avp = Avp(p)
        alpha = rdotr / torch.dot(p, _Avp)  # step size along the current search direction
        x += alpha * p
        r -= alpha * _Avp                   # update the residual b - A x
        new_rdotr = torch.dot(r, r)
        beta = new_rdotr / rdotr
        p = r + beta * p                    # next A-conjugate search direction
        rdotr = new_rdotr
        if rdotr < residual_tol:
            break
    return x
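
# In trpo_step below, conjugate_gradients is called with Avp = Fvp and
# b = -loss_grad, so the returned x approximates F^{-1} g, the natural
# gradient direction.  As a standalone (illustrative) example: for any small
# symmetric positive-definite matrix A,
#     conjugate_gradients(lambda p: A @ p, b, nsteps=10)
# returns x with A @ x ≈ b.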

def linesearch(model,
               f,
               x,
               fullstep,
               expected_improve_rate,
               max_backtracks=10,
               accept_ratio=.1):
    """Backtracking line search: shrink the step until the actual improvement
    is positive and at least accept_ratio of the expected (linear) improvement."""
    fval = f().data
    # print("fval before", fval.item())
    for _n_backtracks, stepfrac in enumerate(.5 ** np.arange(max_backtracks)):
        xnew = x + stepfrac * fullstep
        set_flat_params_to(model, xnew)
        newfval = f().data
        actual_improve = fval - newfval
        expected_improve = expected_improve_rate * stepfrac
        ratio = actual_improve / expected_improve
        # print("a/e/r", actual_improve.item(), expected_improve.item(), ratio.item())
        if ratio.item() > accept_ratio and actual_improve.item() > 0:
            # print("fval after", newfval.item())
            return True, xnew
    return False, x

def trpo_step(model, get_loss, get_kl, max_kl, damping):
    # Surrogate loss and its flattened gradient g at the current parameters.
    loss = get_loss()
    grads = torch.autograd.grad(loss, model.parameters())
    loss_grad = torch.cat([grad.view(-1) for grad in grads]).data

    def Fvp(v):
        # Fisher-vector product F v, computed as the Hessian-vector product of
        # the mean KL divergence, plus damping for numerical stability.
        kl = get_kl()
        kl = kl.mean()

        grads = torch.autograd.grad(kl, model.parameters(), create_graph=True)
        flat_grad_kl = torch.cat([grad.view(-1) for grad in grads])

        kl_v = (flat_grad_kl * Variable(v)).sum()
        grads = torch.autograd.grad(kl_v, model.parameters())
        flat_grad_grad_kl = torch.cat([grad.contiguous().view(-1) for grad in grads]).data

        return flat_grad_grad_kl + v * damping

    # Natural gradient direction s ~ F^{-1} g via conjugate gradients.
    stepdir = conjugate_gradients(Fvp, -loss_grad, 10)

    # Scale the step so the quadratic KL estimate 0.5 * s^T F s hits max_kl.
    shs = 0.5 * (stepdir * Fvp(stepdir)).sum(0, keepdim=True)
    lm = torch.sqrt(shs / max_kl)  # 1 / beta
    fullstep = stepdir / lm[0]     # beta * s

    # Expected improvement rate g . s / lm, used by the backtracking line search.
    neggdotstepdir = (-loss_grad * stepdir).sum(0, keepdim=True)
    # print("lagrange multiplier:", lm[0], "grad_norm:", loss_grad.norm())

    prev_params = get_flat_params_from(model)
    success, new_params = linesearch(model, get_loss, prev_params, fullstep,
                                     neggdotstepdir / lm[0])
    set_flat_params_to(model, new_params)

    return loss
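
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module).  It
# assumes a toy diagonal-Gaussian policy on random data; PolicyNet, the
# states/actions/advantages tensors, and the get_loss/get_kl closures below
# are hypothetical stand-ins for what a real training loop would supply.  It
# also relies on the get_flat_params_from / set_flat_params_to helpers
# imported from utils above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch.nn as nn

    class PolicyNet(nn.Module):
        """Toy policy: state-dependent mean, state-independent log std."""

        def __init__(self, obs_dim, act_dim):
            super().__init__()
            self.mean = nn.Linear(obs_dim, act_dim)
            self.log_std = nn.Parameter(torch.zeros(act_dim))

        def forward(self, x):
            mean = self.mean(x)
            return mean, self.log_std.expand_as(mean)

    torch.manual_seed(0)
    policy = PolicyNet(obs_dim=4, act_dim=2)
    states = torch.randn(32, 4)
    actions = torch.randn(32, 2)
    advantages = torch.randn(32)

    def gaussian_log_prob(mean, log_std, a):
        # Log density of a diagonal Gaussian, summed over action dimensions.
        var = torch.exp(2 * log_std)
        return (-((a - mean) ** 2) / (2 * var)
                - log_std - 0.5 * np.log(2 * np.pi)).sum(dim=1)

    with torch.no_grad():
        old_mean, old_log_std = policy(states)
        old_logp = gaussian_log_prob(old_mean, old_log_std, actions)

    def get_loss():
        # Negative surrogate objective: -E[importance ratio * advantage].
        mean, log_std = policy(states)
        logp = gaussian_log_prob(mean, log_std, actions)
        return -(advantages * torch.exp(logp - old_logp)).mean()

    def get_kl():
        # KL(old || new) per state for diagonal Gaussians; zero at the current
        # parameters, but its Hessian (the Fisher matrix) is what Fvp needs.
        mean, log_std = policy(states)
        var, old_var = torch.exp(2 * log_std), torch.exp(2 * old_log_std)
        kl = (log_std - old_log_std
              + (old_var + (old_mean - mean) ** 2) / (2 * var) - 0.5)
        return kl.sum(dim=1, keepdim=True)

    loss = trpo_step(policy, get_loss, get_kl, max_kl=1e-2, damping=1e-1)
    print("surrogate loss before update:", loss.item())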