From 47488d6dcf69bc540e520840f751fb51da21007b Mon Sep 17 00:00:00 2001 From: lucidrains Date: Fri, 4 Aug 2023 13:35:57 -0700 Subject: [PATCH] get voicebox to be trainable --- README.md | 47 ++++++++++++++++++- setup.py | 1 - voicebox_pytorch/__init__.py | 2 +- voicebox_pytorch/voicebox_pytorch.py | 68 ++++++++++++++++++++++++++-- 4 files changed, 110 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index f4873f4..e2832df 100644 --- a/README.md +++ b/README.md @@ -6,11 +6,54 @@ Implementation of Voicebox, new S In this work, we will use rotary embeddings. The authors seem unaware that ALiBi cannot be straightforwardly used for bidirectional models. +## Install + +```bash +$ pip install voicebox-pytorch +``` + +## Usage + +```python +import torch +from voicebox_pytorch.voicebox_pytorch import ( + VoiceBox, + ConditionalFlowMatcherWrapper +) + +model = VoiceBox( + dim = 512, + num_phoneme_tokens = 256, + depth = 2, + dim_head = 64, + heads = 16 +) + +cfm_wrapper = ConditionalFlowMatcherWrapper( + voicebox = model +) + +x = torch.randn(1, 1024, 512) +phonemes = torch.randint(0, 256, (1, 1024)) +mask = torch.randint(0, 2, (1, 1024)) + +loss = cfm_wrapper( + x, + phoneme_ids = phonemes, + cond = x, + mask = mask +) + +loss.backward() +``` + ## Todo +- [x] read and internalize original flow matching paper + - [x] basic loss + - [ ] get neural ode working with torchdyn - [ ] consider switching to adaptive rmsnorm for time conditioning -- [ ] read and internalize original flow matching paper and build out basic training code -- [ ] take care of mel spec + inverse mel spec +- [ ] integrate with either hifi-gan or soundstream / encodec - [ ] basic trainer ## Citations diff --git a/setup.py b/setup.py index ccaecd6..afb756f 100644 --- a/setup.py +++ b/setup.py @@ -19,7 +19,6 @@ 'beartype', 'einops>=0.6.1', 'torch>=2.0', - 'torchdiffeq', 'torchdyn==1.0.3' ], classifiers=[ diff --git a/voicebox_pytorch/__init__.py b/voicebox_pytorch/__init__.py index bc7b05c..208b07e 100644 --- a/voicebox_pytorch/__init__.py +++ b/voicebox_pytorch/__init__.py @@ -2,5 +2,5 @@ Transformer, VoiceBox, DurationPredictor, - CNFWrapper + ConditionalFlowMatcherWrapper ) diff --git a/voicebox_pytorch/voicebox_pytorch.py b/voicebox_pytorch/voicebox_pytorch.py index 75287b4..0327926 100644 --- a/voicebox_pytorch/voicebox_pytorch.py +++ b/voicebox_pytorch/voicebox_pytorch.py @@ -10,6 +10,8 @@ from voicebox_pytorch.attend import Attend +from torchdyn.core import NeuralODE + # helper functions def exists(val): @@ -451,13 +453,71 @@ def forward( # wrapper for the CNF -class CNFWrapper(Module): +class ConditionalFlowMatcherWrapper(Module): @beartype def __init__( self, - voicebox: VoiceBox + voicebox: VoiceBox, + sigma = 0., + node_solver = 'dopri5', + node_sensitivity = 'adjoint', + node_atol = 1e-5, + node_rtol = 1e-5 ): super().__init__() + self.sigma = sigma + self.voicebox = voicebox + + self.node = NeuralODE( + voicebox, + solver = node_solver, + sensitivity = node_sensitivity, + atol = node_atol, + rtol = node_rtol + ) - def forward(self, x): - return x + def forward( + self, + x1, + *, + phoneme_ids, + cond, + mask = None, + ): + """ + following the example put forth + https://github.com/atong01/conditional-flow-matching/blob/main/torchcfm/conditional_flow_matching.py#L248 + """ + batch, dtype = x1.shape[0], x1.dtype + + x0 = torch.randn_like(x1) + + # random times + + times = torch.rand((batch,), dtype = dtype) + + # sample xt + + mu_t = times * x1 + sigma_t = 1 - (1 - self.sigma) * times + sigma_t = rearrange(sigma_t, 'b -> b 1 1') + + eps = torch.rand_like(x1) + xt = mu_t * sigma_t * eps + + conditional_flow = (x1 - (1 - self.sigma) * xt) / sigma_t + + # predict + + self.voicebox.train() + + loss = self.voicebox( + xt, + phoneme_ids = phoneme_ids, + cond = cond, + mask = mask, + times = times, + target = conditional_flow + ) + + return loss