Skip to content

Commit

Permalink
add a post layernorm for the gateloop layers, in case it is causing instability
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jan 31, 2024
1 parent cdc0777 commit c05a4d0
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'voicebox-pytorch',
packages = find_packages(exclude=[]),
version = '0.4.12',
version = '0.5.0',
license='MIT',
description = 'Voicebox - Pytorch',
author = 'Phil Wang',
Expand All @@ -21,7 +21,7 @@
'naturalspeech2-pytorch>=0.1.8',
'beartype',
'einops>=0.6.1',
'gateloop-transformer>=0.0.25',
'gateloop-transformer>=0.2.4',
'spear-tts-pytorch>=0.4.0',
'torch>=2.0',
'torchdiffeq',
Expand Down
2 changes: 1 addition & 1 deletion voicebox_pytorch/voicebox_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ def __init__(

self.layers.append(nn.ModuleList([
nn.Linear(dim * 2, dim) if has_skip else None,
GateLoop(dim = dim, use_jax_associative_scan = gateloop_use_jax) if use_gateloop_layers else None,
GateLoop(dim = dim, use_jax_associative_scan = gateloop_use_jax, post_ln = True) if use_gateloop_layers else None,
rmsnorm_klass(dim = dim),
Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, flash = attn_flash, qk_norm = attn_qk_norm),
rmsnorm_klass(dim = dim),
Expand Down

0 comments on commit c05a4d0

Please sign in to comment.