Skip to content

Commit

Permalink
just make sure node.traj runs for starters
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Aug 4, 2023
1 parent 64df9a3 commit 8d97526
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 8 deletions.
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,14 @@ loss = cfm_wrapper(
)

loss.backward()

# after much training above...

sampled = cfm_wrapper.sample(
phoneme_ids = phonemes,
cond = x,
mask = mask
)
```

## Todo
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'voicebox-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.1',
version = '0.0.2',
license='MIT',
description = 'Voicebox - Pytorch',
author = 'Phil Wang',
Expand Down
63 changes: 56 additions & 7 deletions voicebox_pytorch/voicebox_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,14 @@ def forward(
):
assert cond.shape[-1] == x.shape[-1]

# auto manage shape of times, for node.traj

if times.ndim == 0:
times = repeat(times, '-> b', b = cond.shape[0])

if times.ndim == 1 and times.shape[0] == 1:
times = repeat(times, '1 -> b', b = cond.shape[0])

# classifier free guidance

if cond_drop_prob > 0.:
Expand Down Expand Up @@ -462,20 +470,60 @@ def __init__(
node_solver = 'dopri5',
node_sensitivity = 'adjoint',
node_atol = 1e-5,
node_rtol = 1e-5
node_rtol = 1e-5,
cond_drop_prob = 0.
):
super().__init__()
self.sigma = sigma

self.voicebox = voicebox
self.cond_drop_prob = cond_drop_prob

self.node = NeuralODE(
voicebox,
self.node_kwargs = dict(
solver = node_solver,
sensitivity = node_sensitivity,
atol = node_atol,
rtol = node_rtol
)

@property
def device(self):
return next(self.parameters()).device

@torch.inference_mode()
def sample(
self,
*,
phoneme_ids,
cond,
mask = None,
steps = 18,
cond_scale = 1.
):
self.voicebox.eval()

def voicebox_forward_with_cond_scale(t, x):
return self.voicebox.forward_with_cond_scale(
x,
phoneme_ids = phoneme_ids,
times = t,
cond = cond,
mask = mask,
cond_scale = cond_scale
)

node = NeuralODE(
voicebox_forward_with_cond_scale,
**self.node_kwargs
)

traj = node.trajectory(
torch.randn_like(cond),
t_span = torch.linspace(0, 1, steps, device = self.device)
)

return traj

def forward(
self,
x1,
Expand All @@ -495,12 +543,12 @@ def forward(
# random times

times = torch.rand((batch,), dtype = dtype)
padded_times = rearrange(times, 'b -> b 1 1')

# sample xt

mu_t = times * x1
sigma_t = 1 - (1 - self.sigma) * times
sigma_t = rearrange(sigma_t, 'b -> b 1 1')
mu_t = padded_times * x1
sigma_t = 1 - (1 - self.sigma) * padded_times

eps = torch.rand_like(x1)
xt = mu_t * sigma_t * eps
Expand All @@ -517,7 +565,8 @@ def forward(
cond = cond,
mask = mask,
times = times,
target = conditional_flow
target = conditional_flow,
cond_drop_prob = self.cond_drop_prob
)

return loss

0 comments on commit 8d97526

Please sign in to comment.