Skip to content

Commit

Permalink
Merge pull request #37 from lucasnewman/null-cond-no-grad
Browse files Browse the repository at this point in the history
Disable gradients for null conditioning when CFG is enabled
  • Loading branch information
lucidrains authored Nov 26, 2023
2 parents 86d1310 + 8377f26 commit 606de71
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions voicebox_pytorch/voicebox_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,12 @@ def mask_from_frac_lengths(
seq_len: int,
frac_lengths: Tensor
):
device = frac_lengths
device = frac_lengths.device

lengths = (frac_lengths * seq_len).long()
max_start = seq_len - lengths

rand = torch.zeros_like(frac_lengths).float().uniform_(0, 1)
rand = torch.zeros_like(frac_lengths, device = device).float().uniform_(0, 1)
start = (max_start * rand).clamp(min = 0)
end = start + lengths

Expand Down Expand Up @@ -562,9 +562,13 @@ def encode(self, audio):
encoded_audio, _, _ = self.encodec(audio, return_encoded = True)
return encoded_audio

def decode(self, latents):
def decode_to_codes(self, latents):
    """Quantize continuous latents into discrete codebook indices.

    Runs the encodec residual vector-quantizer over `latents` and returns
    the code indices rearranged from (batch, seq, quantizers) to
    (batch, quantizers, seq).
    """
    _, quantized_indices, _ = self.encodec.rq(latents)
    # move the quantizer axis ahead of the sequence axis for downstream use
    return rearrange(quantized_indices, 'b n q -> b q n')

def decode(self, latents):
codes = self.decode_to_codes(latents)

all_audios = []
for code in codes:
Expand Down Expand Up @@ -631,7 +635,7 @@ def __init__(

self.to_embed = nn.Linear(dim + dim_phoneme_emb, dim)

self.null_cond = nn.Parameter(torch.zeros(dim))
self.null_cond = nn.Parameter(torch.zeros(dim), requires_grad = False)

self.conv_embed = ConvPositionEmbed(
dim = dim,
Expand Down Expand Up @@ -835,7 +839,7 @@ def forward(
else:
loss_mask = self_attn_mask

if not exists(mask):
if not exists(loss_mask):
return F.l1_loss(x, target)

loss = F.l1_loss(x, target, reduction = 'none')
Expand All @@ -848,7 +852,7 @@ def forward(
loss = num / den
loss = loss.mean()

if not should_align:
if not return_aligned_phoneme_ids:
return loss

#aligner loss
Expand Down Expand Up @@ -920,7 +924,7 @@ def __init__(

self.to_embed = nn.Linear(dim_in * 2 + dim_cond_emb, dim)

self.null_cond = nn.Parameter(torch.zeros(dim_in))
self.null_cond = nn.Parameter(torch.zeros(dim_in), requires_grad = False)

self.conv_embed = ConvPositionEmbed(
dim = dim,
Expand Down Expand Up @@ -1174,6 +1178,7 @@ def sample(
steps = 3,
cond_scale = 1.,
decode_to_audio = True,
decode_to_codes = False,
max_semantic_token_ids = 2048,
spec_decode = False,
spec_decode_gamma = 5 # could be higher, since speech is probably easier than text, needs to be tested
Expand Down Expand Up @@ -1311,6 +1316,9 @@ def fn(t, x, *, packed_shape = None):
sampled = sol.ys[:, -1]
sampled = unpack_one(sampled, packed_shape, 'b *')

if decode_to_codes and exists(self.voicebox.audio_enc_dec):
return self.voicebox.audio_enc_dec.decode_to_codes(sampled)

if not decode_to_audio or not exists(self.voicebox.audio_enc_dec):
return sampled

Expand Down

0 comments on commit 606de71

Please sign in to comment.