diff --git a/setup.py b/setup.py index 3286b76..f9311eb 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'voicebox-pytorch', packages = find_packages(exclude=[]), - version = '0.0.4', + version = '0.0.5', license='MIT', description = 'Voicebox - Pytorch', author = 'Phil Wang', @@ -19,6 +19,7 @@ 'beartype', 'einops>=0.6.1', 'torch>=2.0', + 'torchdiffeq', 'torchdyn==1.0.3', 'torchode' ], diff --git a/voicebox_pytorch/voicebox_pytorch.py b/voicebox_pytorch/voicebox_pytorch.py index 9bdeb18..9cca529 100644 --- a/voicebox_pytorch/voicebox_pytorch.py +++ b/voicebox_pytorch/voicebox_pytorch.py @@ -583,25 +583,26 @@ def forward( mask = None, ): """ - following the example put forth - https://github.com/atong01/conditional-flow-matching/blob/main/torchcfm/conditional_flow_matching.py#L248 + following eq (5) (6) in https://arxiv.org/pdf/2306.15687.pdf + using https://github.com/atong01/conditional-flow-matching/blob/main/torchcfm/conditional_flow_matching.py as reference """ - batch, seq_len, dtype = *x1.shape[:2], x1.dtype + + batch, seq_len, dtype, σ = *x1.shape[:2], x1.dtype, self.sigma + + x0 = torch.randn_like(x1) # random times times = torch.rand((batch,), dtype = dtype) - padded_times = rearrange(times, 'b -> b 1 1') - - # sample xt + t = rearrange(times, 'b -> b 1 1') - mu_t = padded_times * x1 - sigma_t = 1 - (1 - self.sigma) * padded_times + # sample xt (w in the paper) - eps = torch.rand_like(x1) - xt = mu_t + sigma_t * eps + mu_t = t * x1 + σt = (1 - (1 - σ) * t) * x0 + w = mu_t + σt - flow = (x1 - (1 - self.sigma) * xt) / sigma_t + flow = x1 - (1 - σ) * x0 # construct mask if not given @@ -618,7 +619,7 @@ def forward( self.voicebox.train() loss = self.voicebox( - xt, + w, phoneme_ids = phoneme_ids, cond = cond, mask = mask,