Skip to content

Commit

Permalink
make sure to conditionally drop out phoneme ids for CFG
Browse files — browse the repository at this point in the history
  • Loading branch information
lucidrains committed Aug 4, 2023
1 parent c20d85b commit 43a90e8
Showing 1 changed file with 23 additions and 5 deletions.
28 changes: 23 additions & 5 deletions voicebox_pytorch/voicebox_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,10 @@ def __init__(
attn_flash = False
):
super().__init__()
self.to_phoneme_emb = nn.Embedding(num_phoneme_tokens, dim_phoneme_emb)

self.null_phoneme_id = num_phoneme_tokens # use last phoneme token as null token for CFG
self.to_phoneme_emb = nn.Embedding(num_phoneme_tokens + 1, dim_phoneme_emb)

self.to_embed = nn.Linear(dim * 2 + dim_phoneme_emb, dim)

self.null_cond = nn.Parameter(torch.zeros(dim))
Expand Down Expand Up @@ -273,7 +276,6 @@ def forward(
target = None,
mask = None
):
phoneme_emb = self.to_phoneme_emb(phoneme_ids)
assert cond.shape[-1] == x.shape[-1]

# classifier free guidance
Expand All @@ -287,6 +289,14 @@ def forward(
cond
)

phoneme_ids = torch.where(
rearrange(cond_drop_mask, '... -> ... 1'),
self.null_phoneme_id,
phoneme_ids
)

phoneme_emb = self.to_phoneme_emb(phoneme_ids)

# combine audio, phoneme, conditioning

embed = torch.cat((x, phoneme_emb, cond), dim = -1)
Expand Down Expand Up @@ -331,7 +341,9 @@ def __init__(
super().__init__()
self.sinu_pos_emb = LearnedSinusoidalPosEmb(dim)

self.to_phoneme_emb = nn.Embedding(num_phoneme_tokens, dim_phoneme_emb)
self.null_phoneme_id = num_phoneme_tokens # use last phoneme token as null token for CFG
self.to_phoneme_emb = nn.Embedding(num_phoneme_tokens + 1, dim_phoneme_emb)

self.to_embed = nn.Linear(dim * 2 + dim_phoneme_emb, dim)

self.null_cond = nn.Parameter(torch.zeros(dim))
Expand Down Expand Up @@ -373,11 +385,10 @@ def forward(
phoneme_ids,
cond,
times,
cond_drop_prob = 0.,
cond_drop_prob = 0.1,
target = None,
mask = None,
):
phoneme_emb = self.to_phoneme_emb(phoneme_ids)
assert cond.shape[-1] == x.shape[-1]

# classifier free guidance
Expand All @@ -391,6 +402,13 @@ def forward(
cond
)

phoneme_ids = torch.where(
rearrange(cond_drop_mask, '... -> ... 1'),
self.null_phoneme_id,
phoneme_ids
)

phoneme_emb = self.to_phoneme_emb(phoneme_ids)
embed = torch.cat((x, phoneme_emb, cond), dim = -1)
x = self.to_embed(embed)

Expand Down

0 comments on commit 43a90e8

Please sign in to comment.