-
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_diffusion_samplers.py
139 lines (121 loc) · 4.77 KB
/
test_diffusion_samplers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import tensorflow as tf
import pytest
from NN.restorators.diffusion.diffusion_samplers import sampler_from_config
from NN.restorators.diffusion.diffusion_schedulers import CDPDiscrete, get_beta_schedule
def _fake_model(noise_steps):
    """Build a simple stand-in for a denoising model, plus a start value.

    Args:
        noise_steps: number of steps of the schedule; not used by the current
            implementation, kept so callers can pass the schedule's step count.

    Returns:
        A dict with:
          - 'x': a random (32, 3) starting tensor,
          - 'fakeModel': a callable ``(V, T) -> fakeNoise + (T + 1) * V`` that
            asserts ``T`` lies in ``[0, 1]`` on every call,
          - 'fakeNoise': the random (32, 3) tensor used inside 'fakeModel'.
    """
    startValue = tf.random.normal([32, 3])
    fixedNoise = tf.random.normal([32, 3])

    def fakeModel(V, T):
        # The model must only ever be called with T in [0, 1].
        tf.debugging.assert_less_equal(T, 1.0)
        tf.debugging.assert_greater_equal(T, 0.0)
        # The output depends on both V and T, so any noticeable perturbation of
        # the sampler's inputs leads to different samples.
        return fixedNoise + (T + 1) * V

    return {'x': startValue, 'fakeModel': fakeModel, 'fakeNoise': fixedNoise}
def _fake_DDIM(stochasticity, K, noiseProjection=False):
    """Create a DDIM sampler whose sampling noise is configured as 'zero'.

    Args:
        stochasticity: value for the sampler's 'stochasticity' option.
        K: 'K' parameter of the uniform 'steps skip type' (presumably the
            stride between visited steps — confirm against the sampler).
        noiseProjection: value for the 'project noise' option.

    Returns:
        Whatever ``sampler_from_config`` builds for this configuration.
    """
    config = {
        'name': 'DDIM',
        'stochasticity': stochasticity,
        'noise stddev': 'zero',
        'steps skip type': {'name': 'uniform', 'K': K},
        'project noise': noiseProjection,
    }
    return sampler_from_config(config)
# Tests that DDPM and DDIM produce the same samples when:
# - the sampling noise is zero (or identical for both samplers), and
# - DDIM uses K=1 and stochasticity=1.0.
def _fake_samplers():
    """Build a DDIM/DDPM pair configured to be equivalent, plus a fake model.

    DDIM is created with K=1 and stochasticity=1.0, and both samplers use
    'zero' noise. The returned dict holds 'ddim', 'ddpm', 'schedule' (a
    10-step CDPDiscrete schedule with a linear beta schedule) and all the
    keys produced by ``_fake_model`` ('x', 'fakeModel', 'fakeNoise').
    """
    schedule = CDPDiscrete(beta_schedule=get_beta_schedule('linear'), noise_steps=10)
    result = {
        'ddim': _fake_DDIM(stochasticity=1.0, K=1),
        'ddpm': sampler_from_config({'name': 'DDPM', 'noise stddev': 'zero'}),
        'schedule': schedule,
    }
    result.update(_fake_model(schedule.noise_steps))
    return result
def test_DDPM_eq_DDIM_steps():
    """Each DDIM reverse step must match the corresponding DDPM reverse step.

    For every timestep T (from the last one down to 0), the value produced by
    the DDIM step (``x_prev``) must match the DDPM step's first output, and
    their sigmas must agree. Sigma must be strictly positive for T > 0 and
    exactly zero at the final step (T == 0).
    """
    samplers = _fake_samplers()
    ddim = samplers['ddim']
    ddpm = samplers['ddpm']
    schedule = samplers['schedule']
    x = samplers['x']
    fakeModel = samplers['fakeModel']
    ########
    ddimStepF = ddim._reverseStep(fakeModel, schedule=schedule, eta=1.0)
    ddpmStepF = ddpm._reverseStep(fakeModel, schedule=schedule)
    for T in reversed(range(schedule.noise_steps)):
        t = tf.fill((32, 1), T)
        ddimS = ddimStepF(x=x, t=t, tPrev=t - 1)
        # The DDPM step returns (value, sigma); its second output is compared
        # directly against DDIM's sigma below.
        X_ddpm, var_ddpm = ddpmStepF(x=x, t=t)
        tf.debugging.assert_near(ddimS.x_prev, X_ddpm, atol=1e-5, message=f"t={T}")
        tf.debugging.assert_near(ddimS.sigma, var_ddpm, atol=1e-6, message=f"t={T}")
        if 0 < T:
            tf.debugging.assert_greater(ddimS.sigma, 0.0, message=f"t={T}")
            tf.debugging.assert_greater(var_ddpm, 0.0, message=f"t={T}")
        else:
            # The last step should always have zero variance.
            tf.assert_equal(ddimS.sigma, 0.0, message=f"t={T}")
            tf.assert_equal(var_ddpm, 0.0, message=f"t={T}")
def test_DDPM_eq_DDIM_sample():
    """Full sampling with the equivalent DDIM/DDPM pair gives the same output."""
    fixture = _fake_samplers()
    start = fixture['x']
    model = fixture['fakeModel']
    schedule = fixture['schedule']

    ddimResult = fixture['ddim'].sample(value=start, model=model, schedule=schedule)
    ddpmResult = fixture['ddpm'].sample(value=start, model=model, schedule=schedule)
    tf.debugging.assert_near(ddimResult, ddpmResult, atol=1e-6)
def test_DDPM_eq_DDIM_sample_modelCalls():
    """DDPM and the equivalent DDIM must call the model once per noise step."""
    fixture = _fake_samplers()
    start = fixture['x']
    schedule = fixture['schedule']
    baseModel = fixture['fakeModel']

    def countingModel():
        # Wrap the fake model so every invocation bumps a counter stored on
        # the wrapper. A tf.Variable is used, presumably so increments are
        # still recorded if the sampler traces the model inside a tf.function.
        def wrapped(*args, **kwargs):
            wrapped.calls.assign_add(1)
            return baseModel(*args, **kwargs)
        wrapped.calls = tf.Variable(0, dtype=tf.int32)
        return wrapped

    ddpmModel = countingModel()
    ddimModel = countingModel()
    ########
    fixture['ddpm'].sample(value=start, model=ddpmModel, schedule=schedule)
    fixture['ddim'].sample(value=start, model=ddimModel, schedule=schedule)
    tf.assert_equal(ddimModel.calls, ddpmModel.calls)
    tf.assert_equal(ddpmModel.calls, schedule.noise_steps)
    tf.assert_equal(ddimModel.calls, schedule.noise_steps)
# Test that enabling noise projection does not change the samples when the noise is zero.
@pytest.mark.parametrize(
    'stochasticity,K',
    [(1.0, 1), (0.0, 2), (0.5, 3)],
)
def test_DDIM_noiseProjection(stochasticity, K):
    """With 'zero' noise, toggling 'project noise' must not alter DDIM samples."""
    schedule = CDPDiscrete(beta_schedule=get_beta_schedule('linear'), noise_steps=10)
    fake = _fake_model(schedule.noise_steps)
    start = fake['x']
    model = fake['fakeModel']

    # Build the sampler without projection first, then the one with it.
    samplerPlain, samplerProjected = (
        _fake_DDIM(stochasticity=stochasticity, K=K, noiseProjection=flag)
        for flag in (False, True)
    )
    plain = samplerPlain.sample(value=start, model=model, schedule=schedule)
    projected = samplerProjected.sample(value=start, model=model, schedule=schedule)
    tf.debugging.assert_near(plain, projected, atol=1e-6)
# verify that stochasticity has an effect even if noise is zero
def test_DDIM_stochasticity_effect():
    """DDIM with stochasticity 0.0 vs 1.0 must give different samples.

    Both samplers use 'zero' noise, so any difference comes from the
    stochasticity setting alone. "Different" means the mean absolute
    difference exceeds 1e-3.
    """
    schedule = CDPDiscrete(beta_schedule=get_beta_schedule('linear'), noise_steps=10)
    fake = _fake_model(schedule.noise_steps)
    start = fake['x']
    model = fake['fakeModel']

    ddimNoStoch = _fake_DDIM(stochasticity=0.0, K=1)
    ddimFullStoch = _fake_DDIM(stochasticity=1.0, K=1)
    samplesNoStoch = ddimNoStoch.sample(value=start, model=model, schedule=schedule)
    samplesFullStoch = ddimFullStoch.sample(value=start, model=model, schedule=schedule)
    meanAbsDiff = tf.reduce_mean(tf.abs(samplesNoStoch - samplesFullStoch))
    tf.debugging.assert_greater(meanAbsDiff, 1e-3)