get voicebox to be trainable
lucidrains committed Aug 4, 2023
1 parent 43a90e8 commit 47488d6
Showing 4 changed files with 110 additions and 8 deletions.
47 changes: 45 additions & 2 deletions README.md
@@ -6,11 +6,54 @@ Implementation of <a href="https://arxiv.org/abs/2306.15687">Voicebox</a>, new S

In this work, we will use rotary embeddings. The authors seem unaware that ALiBi cannot be straightforwardly used for bidirectional models.

## Install

```bash
$ pip install voicebox-pytorch
```

## Usage

```python
import torch
from voicebox_pytorch.voicebox_pytorch import (
    VoiceBox,
    ConditionalFlowMatcherWrapper
)

model = VoiceBox(
    dim = 512,
    num_phoneme_tokens = 256,
    depth = 2,
    dim_head = 64,
    heads = 16
)

cfm_wrapper = ConditionalFlowMatcherWrapper(
    voicebox = model
)

x = torch.randn(1, 1024, 512)
phonemes = torch.randint(0, 256, (1, 1024))
mask = torch.randint(0, 2, (1, 1024))

loss = cfm_wrapper(
    x,
    phoneme_ids = phonemes,
    cond = x,
    mask = mask
)

loss.backward()
```

## Todo

- [x] read and internalize original flow matching paper
- [x] basic loss
- [ ] get neural ode working with torchdyn
- [ ] consider switching to adaptive rmsnorm for time conditioning
- [ ] read and internalize original flow matching paper and build out basic training code
- [ ] take care of mel spec + inverse mel spec
- [ ] integrate with either hifi-gan or soundstream / encodec
- [ ] basic trainer
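
For orientation on the neural ODE item: sampling from a flow-matching model means integrating dx/dt = v(x, t) from noise at t = 0 to t = 1. The following is a minimal sketch under stated assumptions, not this repository's sampler. It uses plain Python scalars instead of tensors, a fixed-step Euler loop in place of an adaptive solver such as `dopri5`, and the closed-form conditional field in place of a trained network. The names `conditional_flow` and `euler_integrate` are hypothetical.

```python
def conditional_flow(x, t, x1, sigma=0.0):
    """Closed-form conditional vector field (stand-in for a trained model)."""
    sigma_t = 1.0 - (1.0 - sigma) * t
    return (x1 - (1.0 - sigma) * x) / sigma_t


def euler_integrate(x0, x1, sigma=0.0, steps=100):
    """Integrate dx/dt = conditional_flow from t = 0 to t = 1 with fixed steps."""
    x = x0
    dt = 1.0 / steps
    for i in range(steps):
        # the field is evaluated at the left edge of each step, so t < 1
        # and sigma_t never reaches zero even when sigma = 0
        t = i * dt
        x = x + dt * conditional_flow(x, t, x1, sigma)
    return x


# start from a "noise" value and flow toward the "data" value
x_end = euler_integrate(x0=-1.5, x1=2.0, sigma=0.0)
```

With `sigma = 0` the trajectory is a straight line from `x0` to `x1`, so the Euler steps land on `x1` up to floating-point error. With `sigma > 0` the endpoint is `x1 + sigma * x0`.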

## Citations
1 change: 0 additions & 1 deletion setup.py
@@ -19,7 +19,6 @@
'beartype',
'einops>=0.6.1',
'torch>=2.0',
-'torchdiffeq',
'torchdyn==1.0.3'
],
classifiers=[
2 changes: 1 addition & 1 deletion voicebox_pytorch/__init__.py
@@ -2,5 +2,5 @@
Transformer,
VoiceBox,
DurationPredictor,
-CNFWrapper
+ConditionalFlowMatcherWrapper
)
68 changes: 64 additions & 4 deletions voicebox_pytorch/voicebox_pytorch.py
@@ -10,6 +10,8 @@

from voicebox_pytorch.attend import Attend

+from torchdyn.core import NeuralODE

# helper functions

def exists(val):
@@ -451,13 +453,71 @@ def forward(

# wrapper for the CNF

-class CNFWrapper(Module):
+class ConditionalFlowMatcherWrapper(Module):
     @beartype
     def __init__(
         self,
-        voicebox: VoiceBox
+        voicebox: VoiceBox,
+        sigma = 0.,
+        node_solver = 'dopri5',
+        node_sensitivity = 'adjoint',
+        node_atol = 1e-5,
+        node_rtol = 1e-5
     ):
         super().__init__()
+        self.sigma = sigma
         self.voicebox = voicebox

+        self.node = NeuralODE(
+            voicebox,
+            solver = node_solver,
+            sensitivity = node_sensitivity,
+            atol = node_atol,
+            rtol = node_rtol
+        )
+
-    def forward(self, x):
-        return x
+    def forward(
+        self,
+        x1,
+        *,
+        phoneme_ids,
+        cond,
+        mask = None,
+    ):
+        """
+        following the example put forth
+        https://github.com/atong01/conditional-flow-matching/blob/main/torchcfm/conditional_flow_matching.py#L248
+        """
+        batch, dtype, device = x1.shape[0], x1.dtype, x1.device
+
+        # gaussian noise, same shape as the data
+
+        x0 = torch.randn_like(x1)
+
+        # random times, one per batch element, on the same device as the data
+
+        times = torch.rand((batch,), dtype = dtype, device = device)
+        t = rearrange(times, 'b -> b 1 1')
+
+        # sample xt from a gaussian with mean mu_t and std sigma_t
+
+        mu_t = t * x1
+        sigma_t = 1 - (1 - self.sigma) * t
+
+        xt = mu_t + sigma_t * x0
+
+        # regression target: the conditional vector field at xt
+
+        conditional_flow = (x1 - (1 - self.sigma) * xt) / sigma_t
+
+        # predict
+
+        self.voicebox.train()
+
+        loss = self.voicebox(
+            xt,
+            phoneme_ids = phoneme_ids,
+            cond = cond,
+            mask = mask,
+            times = times,
+            target = conditional_flow
+        )
+
+        return loss
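
The sampling step and regression target can be checked numerically. This is a minimal sketch, assuming the Gaussian path `xt = mu_t + sigma_t * x0` from the flow-matching reference linked in the docstring. It uses plain Python scalars rather than tensors to stay dependency-free, and `sample_xt_and_target` is a hypothetical helper, not part of the library. It shows that the target is independent of `t` and equals the straight-line velocity `x1 - (1 - sigma) * x0`.

```python
import random


def sample_xt_and_target(x1, x0, t, sigma=0.0):
    """Scalar version of the sampling step: returns (xt, target flow)."""
    mu_t = t * x1
    sigma_t = 1.0 - (1.0 - sigma) * t
    xt = mu_t + sigma_t * x0
    target = (x1 - (1.0 - sigma) * xt) / sigma_t
    return xt, target


random.seed(0)
x1 = random.gauss(0.0, 1.0)  # stand-in for one data value
x0 = random.gauss(0.0, 1.0)  # gaussian noise
sigma = 0.1

# largest deviation of the target from the straight-line velocity
max_err = 0.0
for t in (0.0, 0.25, 0.5, 0.9):
    xt, target = sample_xt_and_target(x1, x0, t, sigma)
    expected = x1 - (1.0 - sigma) * x0
    max_err = max(max_err, abs(target - expected))
```

Because the target does not depend on `xt` once `x0` and `x1` are fixed, the network is regressing onto a constant velocity along each noise-to-data line.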
