From 8377f26d589cdf7b92ca22bf5c92fd11e5976139 Mon Sep 17 00:00:00 2001 From: Lucas Newman Date: Sat, 25 Nov 2023 16:26:58 -0800 Subject: [PATCH] Disable gradients for null conditioning when CFG is enabled. --- voicebox_pytorch/voicebox_pytorch.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/voicebox_pytorch/voicebox_pytorch.py b/voicebox_pytorch/voicebox_pytorch.py index eb890ab..8e0d700 100644 --- a/voicebox_pytorch/voicebox_pytorch.py +++ b/voicebox_pytorch/voicebox_pytorch.py @@ -137,12 +137,12 @@ def mask_from_frac_lengths( seq_len: int, frac_lengths: Tensor ): - device = frac_lengths + device = frac_lengths.device lengths = (frac_lengths * seq_len).long() max_start = seq_len - lengths - rand = torch.zeros_like(frac_lengths).float().uniform_(0, 1) + rand = torch.zeros_like(frac_lengths, device = device).float().uniform_(0, 1) start = (max_start * rand).clamp(min = 0) end = start + lengths @@ -562,9 +562,13 @@ def encode(self, audio): encoded_audio, _, _ = self.encodec(audio, return_encoded = True) return encoded_audio - def decode(self, latents): + def decode_to_codes(self, latents): _, codes, _ = self.encodec.rq(latents) codes = rearrange(codes, 'b n q -> b q n') + return codes + + def decode(self, latents): + codes = self.decode_to_codes(latents) all_audios = [] for code in codes: @@ -631,7 +635,7 @@ def __init__( self.to_embed = nn.Linear(dim + dim_phoneme_emb, dim) - self.null_cond = nn.Parameter(torch.zeros(dim)) + self.null_cond = nn.Parameter(torch.zeros(dim), requires_grad = False) self.conv_embed = ConvPositionEmbed( dim = dim, @@ -835,7 +839,7 @@ def forward( else: loss_mask = self_attn_mask - if not exists(mask): + if not exists(loss_mask): return F.l1_loss(x, target) loss = F.l1_loss(x, target, reduction = 'none') @@ -848,7 +852,7 @@ def forward( loss = num / den loss = loss.mean() - if not should_align: + if not return_aligned_phoneme_ids: return loss #aligner loss @@ -920,7 +924,7 @@ def __init__( self.to_embed = nn.Linear(dim_in * 2 + dim_cond_emb, dim) - self.null_cond = nn.Parameter(torch.zeros(dim_in)) + self.null_cond = nn.Parameter(torch.zeros(dim_in), requires_grad = False) self.conv_embed = ConvPositionEmbed( dim = dim, @@ -1174,6 +1178,7 @@ def sample( steps = 3, cond_scale = 1., decode_to_audio = True, + decode_to_codes = False, max_semantic_token_ids = 2048, spec_decode = False, spec_decode_gamma = 5 # could be higher, since speech is probably easier than text, needs to be tested @@ -1311,6 +1316,9 @@ def fn(t, x, *, packed_shape = None): sampled = sol.ys[:, -1] sampled = unpack_one(sampled, packed_shape, 'b *') + if decode_to_codes and exists(self.voicebox.audio_enc_dec): + return self.voicebox.audio_enc_dec.decode_to_codes(sampled) + if not decode_to_audio or not exists(self.voicebox.audio_enc_dec): return sampled