From c05a4d0c69920993b47069e22223677174d873e4 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Wed, 31 Jan 2024 09:34:34 -0800 Subject: [PATCH] add a post layernorm for the gateloop layers, in case it is causing instability --- setup.py | 4 ++-- voicebox_pytorch/voicebox_pytorch.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 87c833e..df3acea 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'voicebox-pytorch', packages = find_packages(exclude=[]), - version = '0.4.12', + version = '0.5.0', license='MIT', description = 'Voicebox - Pytorch', author = 'Phil Wang', @@ -21,7 +21,7 @@ 'naturalspeech2-pytorch>=0.1.8', 'beartype', 'einops>=0.6.1', - 'gateloop-transformer>=0.0.25', + 'gateloop-transformer>=0.2.4', 'spear-tts-pytorch>=0.4.0', 'torch>=2.0', 'torchdiffeq', diff --git a/voicebox_pytorch/voicebox_pytorch.py b/voicebox_pytorch/voicebox_pytorch.py index 701389a..420eac3 100644 --- a/voicebox_pytorch/voicebox_pytorch.py +++ b/voicebox_pytorch/voicebox_pytorch.py @@ -396,7 +396,7 @@ def __init__( self.layers.append(nn.ModuleList([ nn.Linear(dim * 2, dim) if has_skip else None, - GateLoop(dim = dim, use_jax_associative_scan = gateloop_use_jax) if use_gateloop_layers else None, + GateLoop(dim = dim, use_jax_associative_scan = gateloop_use_jax, post_ln = True) if use_gateloop_layers else None, rmsnorm_klass(dim = dim), Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, flash = attn_flash, qk_norm = attn_qk_norm), rmsnorm_klass(dim = dim),