lucidrains/voicebox-pytorch


Voicebox - Pytorch

Implementation of Voicebox, a new SOTA text-to-speech model from MetaAI, in PyTorch. (Press release)

This repository uses rotary embeddings rather than ALiBi. The authors seem unaware that ALiBi cannot be straightforwardly used for bidirectional models.
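Below is a minimal sketch of rotary position embeddings in the common "rotate-half" formulation. The helper names (rotary_freqs, rotate_half, apply_rotary) and the base theta = 10000 are illustrative assumptions, not necessarily this repository's code. Each query and key is rotated by an angle proportional to its absolute position, so their dot product depends only on the relative offset. Because of that, no causal mask is needed and the scheme works for bidirectional attention.

```python
import torch

def rotary_freqs(seq_len: int, dim: int, theta: float = 10000.0) -> torch.Tensor:
    # one frequency per pair of channels -> angles of shape (seq_len, dim // 2)
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    positions = torch.arange(seq_len).float()
    angles = torch.outer(positions, inv_freq)
    # duplicate so channel i and channel i + dim // 2 share the same angle: (seq_len, dim)
    return torch.cat((angles, angles), dim=-1)

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # (x1, x2) -> (-x2, x1) over the two halves of the last dimension
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary(x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    # x: (..., seq_len, dim); rotates each (x1[i], x2[i]) pair by its positional angle
    return x * freqs.cos() + rotate_half(x) * freqs.sin()

# usage: queries and keys shaped (batch, heads, seq_len, dim_head)
q = torch.randn(2, 16, 128, 64)
k = torch.randn(2, 16, 128, 64)
freqs = rotary_freqs(seq_len=128, dim=64)
q_rot, k_rot = apply_rotary(q, freqs), apply_rotary(k, freqs)
```

The rotation is applied to queries and keys only, so attention scores depend on relative positions while the values are left untouched.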

The paper also has an issue with the time embedding being incorrectly subjected to relative distances (they concatenate the time embedding along the frame dimension of the audio tokens). This repository instead uses adaptive normalization, as applied successfully in Paella.
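Here is a minimal sketch of adaptive RMS normalization for time conditioning, assuming the flow-matching time t is turned into a sinusoidal embedding. That embedding then predicts a per-channel gain applied identically to every frame. The class and function names, the sinusoidal embedding, and the zero initialization are illustrative assumptions rather than this repository's exact code. The point is that t modulates the features instead of being concatenated as an extra frame, so it is never exposed to relative position.

```python
import math
import torch
from torch import nn

def sinusoidal_time_embedding(t: torch.Tensor, dim: int) -> torch.Tensor:
    # t: (batch,) flow-matching times in [0, 1] -> (batch, dim) embedding
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half) / half)
    args = t[:, None] * freqs[None, :]
    return torch.cat((args.sin(), args.cos()), dim=-1)

class AdaptiveRMSNorm(nn.Module):
    # hypothetical module: RMSNorm whose gain is predicted from a conditioning vector
    def __init__(self, dim: int, cond_dim: int):
        super().__init__()
        self.scale = dim ** 0.5
        self.to_gamma = nn.Linear(cond_dim, dim)
        # zero init: the layer starts out as a plain unit-gain RMSNorm
        nn.init.zeros_(self.to_gamma.weight)
        nn.init.zeros_(self.to_gamma.bias)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        # x: (batch, frames, dim), cond: (batch, cond_dim)
        # x / ||x|| * sqrt(dim) equals x / rms(x)
        normed = nn.functional.normalize(x, dim=-1) * self.scale
        gamma = self.to_gamma(cond)[:, None, :]  # same gain for every frame
        return normed * (1.0 + gamma)

# usage
x = torch.randn(2, 1024, 512)
t = torch.rand(2)
norm = AdaptiveRMSNorm(dim=512, cond_dim=128)
out = norm(x, sinusoidal_time_embedding(t, 128))
```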

Update: I recommend just using E2 TTS instead of this work.

Appreciation

  • Translated for awarding me the Imminent Grant to advance the state of open source text-to-speech solutions. This project was started and will be completed under this grant.

  • StabilityAI for the generous sponsorship, as well as my other sponsors, for affording me the independence to open source artificial intelligence.

  • Bryan Chiang for the ongoing code review, sharing his expertise on TTS, and pointing me to an open sourced implementation of conditional flow matching

  • Manmay for getting the repository started with the alignment code

  • @chenht2010 for finding a bug with rotary positions, and for validating that the code in the repository converges

  • Lucas Newman for (yet again) contributing, via pull request, all the training code for Spear-TTS-conditioned Voicebox!

  • Lucas Newman has demonstrated that the whole system works with Spear-TTS conditioning. Training converges even better than with SoundStorm.

Install

$ pip install voicebox-pytorch

Usage

Training and sampling with the TextToSemantic module from Spear-TTS

import torch

from voicebox_pytorch import (
    VoiceBox,
    EncodecVoco,
    ConditionalFlowMatcherWrapper,
    HubertWithKmeans,
    TextToSemantic
)

# https://github.com/facebookresearch/fairseq/tree/main/examples/hubert

wav2vec = HubertWithKmeans(
    checkpoint_path = '/path/to/hubert/checkpoint.pt',
    kmeans_path = '/path/to/hubert/kmeans.bin'
)

text_to_semantic = TextToSemantic(
    wav2vec = wav2vec,
    dim = 512,
    source_depth = 1,
    target_depth = 1,
    use_openai_tokenizer = True
)

text_to_semantic.load('/path/to/trained/spear-tts/model.pt')

model = VoiceBox(
    dim = 512,
    audio_enc_dec = EncodecVoco(),
    num_cond_tokens = 500,
    depth = 2,
    dim_head = 64,
    heads = 16
)

cfm_wrapper = ConditionalFlowMatcherWrapper(
    voicebox = model,
    text_to_semantic = text_to_semantic
)

# mock data

audio = torch.randn(2, 12000)

# train

loss = cfm_wrapper(audio)
loss.backward()

# after much training

texts = [
    'the rain in spain falls mainly in the plains',
    'she sells sea shells by the seashore'
]

cond = torch.randn(2, 12000)
sampled = cfm_wrapper.sample(cond = cond, texts = texts) # (2, 1, <audio length>)

For unconditional training, condition_on_text on VoiceBox must be set to False.

import torch
from voicebox_pytorch import (
    VoiceBox,
    ConditionalFlowMatcherWrapper
)

model = VoiceBox(
    dim = 512,
    num_cond_tokens = 500,
    depth = 2,
    dim_head = 64,
    heads = 16,
    condition_on_text = False
)

cfm_wrapper = ConditionalFlowMatcherWrapper(
    voicebox = model
)

# mock data

x = torch.randn(2, 1024, 512)

# train

loss = cfm_wrapper(x)

loss.backward()

# after much training

cond = torch.randn(2, 1024, 512)

sampled = cfm_wrapper.sample(cond = cond) # (2, 1024, 512)
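Conceptually, the training call above regresses a velocity field with a conditional flow matching objective, and sampling integrates that field from noise to data. The sketch below is a self-contained illustration of the objective. It assumes the optimal-transport probability path with a small sigma_min, plus an optional mask that restricts the loss to masked frames. cfm_loss, euler_sample, ToyVelocity and SIGMA_MIN are hypothetical names. The fixed-step Euler loop stands in for the torchdyn / torchdiffeq / torchode solvers mentioned in the Todo. This is not this repository's code.

```python
import torch
from torch import nn

SIGMA_MIN = 1e-5  # assumed value for the minimum noise level of the OT path

def cfm_loss(model, x1, cond, mask=None):
    # x1: (batch, frames, dim) target features; cond: conditioning of the same shape
    x0 = torch.randn_like(x1)                      # noise sample
    t = torch.rand(x1.shape[0], device=x1.device)  # one time per example
    t_ = t[:, None, None]
    xt = (1 - (1 - SIGMA_MIN) * t_) * x0 + t_ * x1  # point on the OT path
    target = x1 - (1 - SIGMA_MIN) * x0              # d(xt)/dt along that path
    loss = (model(xt, t, cond) - target).pow(2).mean(dim=-1)  # (batch, frames)
    if mask is not None:
        return loss[mask].mean()  # only score frames that must be generated
    return loss.mean()

@torch.no_grad()
def euler_sample(model, cond, steps=32):
    # integrate dx/dt = v(x, t, cond) from t = 0 (noise) to t = 1 (data)
    x = torch.randn_like(cond)
    dt = 1.0 / steps
    for i in range(steps):
        t = torch.full((cond.shape[0],), i * dt, device=cond.device)
        x = x + dt * model(x, t, cond)
    return x

class ToyVelocity(nn.Module):
    # hypothetical stand-in for VoiceBox: maps (x_t, t, cond) -> velocity
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Linear(2 * dim + 1, dim)

    def forward(self, x, t, cond):
        t = t[:, None, None].expand(-1, x.shape[1], 1)
        return self.net(torch.cat((x, cond, t), dim=-1))

# usage, mirroring the shapes in the unconditional example
model = ToyVelocity(512)
x = torch.randn(2, 1024, 512)
loss = cfm_loss(model, x, cond=torch.zeros_like(x))
loss.backward()
sampled = euler_sample(model, cond=torch.randn(2, 1024, 512))
```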

Todo

  • read and internalize original flow matching paper

    • basic loss
    • get neural ode working with torchdyn
  • get basic mask generation logic with a p_drop of 0.2-0.3 for ICL (in-context learning)

  • take care of p_drop, which differs between Voicebox and the duration model

  • support torchdiffeq and torchode

  • switch to adaptive rmsnorm for time conditioning

  • add encodec / voco for starters

  • setup training and sampling with raw audio, if audio_enc_dec is passed in

  • integrate with log mel spec / encodec - vocos

  • spear-tts-integration

  • basic accelerate trainer - thanks to @lucasnewman!

  • cleanup NS2 aligner class and then setup duration predictor training

  • figure out the correct settings for MelVoco encode, as the reconstructed audio is longer in length

  • calculate how many seconds correspond to each frame and add it as a property on AudioEncoderDecoder, so that lengths can be specified in seconds when sampling
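The ICL mask generation item above can be sketched as follows, under stated assumptions. With probability p_drop the whole context is dropped, so every frame is masked. Otherwise a single contiguous span covering a random fraction of the frames is masked. The function name make_icl_mask and the 0.7 to 1.0 span-fraction defaults are illustrative assumptions, not values confirmed by this repository. As the list notes, p_drop would differ between the audio model and the duration model.

```python
import torch

def make_icl_mask(batch, frames, p_drop=0.3, min_frac=0.7, max_frac=1.0, generator=None):
    # hypothetical helper: returns a bool mask (batch, frames), True = frame to generate
    mask = torch.zeros(batch, frames, dtype=torch.bool)
    for b in range(batch):
        if torch.rand(1, generator=generator).item() < p_drop:
            mask[b] = True  # drop the audio context entirely
            continue
        frac = min_frac + (max_frac - min_frac) * torch.rand(1, generator=generator).item()
        span = max(1, round(frac * frames))
        start = torch.randint(0, frames - span + 1, (1,), generator=generator).item()
        mask[b, start:start + span] = True  # one contiguous masked span
    return mask

# usage: zero out masked frames to build the context the model conditions on
mask = make_icl_mask(batch=4, frames=100)
features = torch.randn(4, 100, 512)
context = features.masked_fill(mask[..., None], 0.0)
```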

Citations

@article{Le2023VoiceboxTM,
    title   = {Voicebox: Text-Guided Multilingual Universal Speech Generation at Scale},
    author  = {Matt Le and Apoorv Vyas and Bowen Shi and Brian Karrer and Leda Sari and Rashel Moritz and Mary Williamson and Vimal Manohar and Yossi Adi and Jay Mahadeokar and Wei-Ning Hsu},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2306.15687},
    url     = {https://api.semanticscholar.org/CorpusID:259275061}
}
@inproceedings{dao2022flashattention,
    title   = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
    author  = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
    booktitle = {Advances in Neural Information Processing Systems},
    year    = {2022}
}
@misc{torchdiffeq,
    author  = {Chen, Ricky T. Q.},
    title   = {torchdiffeq},
    year    = {2018},
    url     = {https://github.com/rtqichen/torchdiffeq},
}
@inproceedings{lienen2022torchode,
    title     = {torchode: A Parallel {ODE} Solver for PyTorch},
    author    = {Marten Lienen and Stephan G{\"u}nnemann},
    booktitle = {The Symbiosis of Deep Learning and Differential Equations II, NeurIPS},
    year      = {2022},
    url       = {https://openreview.net/forum?id=uiKVKTiUYB0}
}
@article{siuzdak2023vocos,
    title   = {Vocos: Closing the gap between time-domain and Fourier-based neural vocoders for high-quality audio synthesis},
    author  = {Siuzdak, Hubert},
    journal = {arXiv preprint arXiv:2306.00814},
    year    = {2023}
}
@misc{darcet2023vision,
    title   = {Vision Transformers Need Registers},
    author  = {Timothée Darcet and Maxime Oquab and Julien Mairal and Piotr Bojanowski},
    year    = {2023},
    eprint  = {2309.16588},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@inproceedings{Dehghani2023ScalingVT,
    title   = {Scaling Vision Transformers to 22 Billion Parameters},
    author  = {Mostafa Dehghani and Josip Djolonga and Basil Mustafa and Piotr Padlewski and Jonathan Heek and Justin Gilmer and Andreas Steiner and Mathilde Caron and Robert Geirhos and Ibrahim M. Alabdulmohsin and Rodolphe Jenatton and Lucas Beyer and Michael Tschannen and Anurag Arnab and Xiao Wang and Carlos Riquelme and Matthias Minderer and Joan Puigcerver and Utku Evci and Manoj Kumar and Sjoerd van Steenkiste and Gamaleldin F. Elsayed and Aravindh Mahendran and Fisher Yu and Avital Oliver and Fantine Huot and Jasmijn Bastings and Mark Collier and Alexey A. Gritsenko and Vighnesh Birodkar and Cristina Nader Vasconcelos and Yi Tay and Thomas Mensink and Alexander Kolesnikov and Filip Paveti{\'c} and Dustin Tran and Thomas Kipf and Mario Lu{\v{c}}i{\'c} and Xiaohua Zhai and Daniel Keysers and Jeremiah Harmsen and Neil Houlsby},
    booktitle = {International Conference on Machine Learning},
    year    = {2023},
    url     = {https://api.semanticscholar.org/CorpusID:256808367}
}
@inproceedings{Katsch2023GateLoopFD,
    title   = {GateLoop: Fully Data-Controlled Linear Recurrence for Sequence Modeling},
    author  = {Tobias Katsch},
    year    = {2023},
    url     = {https://api.semanticscholar.org/CorpusID:265018962}
}