[transformer] add rms-norm (#2396)
* [transformer] add rms-norm

* fix assert
Mddct authored Mar 8, 2024
1 parent e5fd5c0 commit d01715a
Showing 7 changed files with 82 additions and 25 deletions.
8 changes: 5 additions & 3 deletions wenet/transformer/convolution.py
@@ -19,6 +19,8 @@
 import torch
 from torch import nn
 
+from wenet.utils.class_utils import WENET_NORM_CLASSES
+
 
 class ConvolutionModule(nn.Module):
     """ConvolutionModule in Conformer model."""
@@ -68,13 +70,13 @@ def __init__(self,
             bias=bias,
         )
 
-        assert norm in ['batch_norm', 'layer_norm']
+        assert norm in ['batch_norm', 'layer_norm', 'rms_norm']
         if norm == "batch_norm":
             self.use_layer_norm = False
-            self.norm = nn.BatchNorm1d(channels)
+            self.norm = WENET_NORM_CLASSES['batch_norm'](channels)
         else:
             self.use_layer_norm = True
-            self.norm = nn.LayerNorm(channels)
+            self.norm = WENET_NORM_CLASSES[norm](channels)
 
         self.pointwise_conv2 = nn.Conv1d(
             channels,

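A minimal construction sketch of the updated module (a hedged example, not part of the diff; it assumes ConvolutionModule's other arguments keep their defaults and that channels is its first parameter):

    from wenet.transformer.convolution import ConvolutionModule

    # 256 is an illustrative channel count; norm='rms_norm' is the new option.
    conv = ConvolutionModule(channels=256, norm='rms_norm')
    print(type(conv.norm).__name__)  # RMSNorm, resolved via WENET_NORM_CLASSES
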
6 changes: 5 additions & 1 deletion wenet/transformer/decoder.py
@@ -25,6 +25,7 @@
     WENET_ATTENTION_CLASSES,
     WENET_ACTIVATION_CLASSES,
     WENET_MLP_CLASSES,
+    WENET_NORM_CLASSES,
 )
 from wenet.utils.common import mask_to_bias
 from wenet.utils.mask import (subsequent_mask, make_pad_mask)
@@ -81,6 +82,7 @@ def __init__(
         tie_word_embedding: bool = False,
         use_sdpa: bool = False,
         mlp_type: str = 'position_wise_feed_forward',
+        layer_norm_type: str = 'layer_norm',
     ):
         super().__init__()
         attention_dim = encoder_output_size
@@ -93,8 +95,10 @@ def __init__(
                                            positional_dropout_rate),
         )
 
+        assert layer_norm_type in ['layer_norm', 'rms_norm']
         self.normalize_before = normalize_before
-        self.after_norm = torch.nn.LayerNorm(attention_dim, eps=1e-5)
+        self.after_norm = WENET_NORM_CLASSES[layer_norm_type](attention_dim,
+                                                              eps=1e-5)
         self.use_output_layer = use_output_layer
         if use_output_layer:
             self.output_layer = torch.nn.Linear(attention_dim, vocab_size)

10 changes: 7 additions & 3 deletions wenet/transformer/decoder_layer.py
@@ -18,6 +18,8 @@
 import torch
 from torch import nn
 
+from wenet.utils.class_utils import WENET_NORM_CLASSES
+
 
 class DecoderLayer(nn.Module):
     """Single decoder layer module.
@@ -46,16 +48,18 @@ def __init__(
         feed_forward: nn.Module,
         dropout_rate: float,
         normalize_before: bool = True,
+        layer_norm_type: str = 'layer_norm',
     ):
         """Construct an DecoderLayer object."""
         super().__init__()
         self.size = size
         self.self_attn = self_attn
         self.src_attn = src_attn
         self.feed_forward = feed_forward
-        self.norm1 = nn.LayerNorm(size, eps=1e-5)
-        self.norm2 = nn.LayerNorm(size, eps=1e-5)
-        self.norm3 = nn.LayerNorm(size, eps=1e-5)
+        assert layer_norm_type in ['layer_norm', 'rms_norm']
+        self.norm1 = WENET_NORM_CLASSES[layer_norm_type](size, eps=1e-5)
+        self.norm2 = WENET_NORM_CLASSES[layer_norm_type](size, eps=1e-5)
+        self.norm3 = WENET_NORM_CLASSES[layer_norm_type](size, eps=1e-5)
         self.dropout = nn.Dropout(dropout_rate)
         self.normalize_before = normalize_before
 

29 changes: 18 additions & 11 deletions wenet/transformer/encoder.py
@@ -25,6 +25,7 @@
 from wenet.utils.class_utils import (
     WENET_EMB_CLASSES,
     WENET_MLP_CLASSES,
+    WENET_NORM_CLASSES,
     WENET_SUBSAMPLE_CLASSES,
     WENET_ATTENTION_CLASSES,
     WENET_ACTIVATION_CLASSES,
@@ -55,6 +56,7 @@ def __init__(
         use_dynamic_left_chunk: bool = False,
         gradient_checkpointing: bool = False,
         use_sdpa: bool = False,
+        layer_norm_type: str = 'layer_norm',
     ):
         """
         Args:
@@ -102,8 +104,10 @@ def __init__(
                                                   positional_dropout_rate),
         )
 
+        assert layer_norm_type in ['layer_norm', 'rms_norm']
         self.normalize_before = normalize_before
-        self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
+        self.after_norm = WENET_NORM_CLASSES[layer_norm_type](output_size,
+                                                              eps=1e-5)
        self.static_chunk_size = static_chunk_size
        self.use_dynamic_chunk = use_dynamic_chunk
        self.use_dynamic_left_chunk = use_dynamic_left_chunk
@@ -368,6 +372,7 @@ def __init__(
         gradient_checkpointing: bool = False,
         use_sdpa: bool = False,
         mlp_type: str = 'position_wise_feed_forward',
+        layer_norm_type: str = 'layer_norm',
     ):
         """ Construct TransformerEncoder
@@ -379,19 +384,21 @@
                          input_layer, pos_enc_layer_type, normalize_before,
                          static_chunk_size, use_dynamic_chunk, global_cmvn,
                          use_dynamic_left_chunk, gradient_checkpointing,
-                         use_sdpa)
+                         use_sdpa, layer_norm_type)
         activation = WENET_ACTIVATION_CLASSES[activation_type]()
         mlp_class = WENET_MLP_CLASSES[mlp_type]
         self.encoders = torch.nn.ModuleList([
-            TransformerEncoderLayer(
-                output_size,
-                WENET_ATTENTION_CLASSES["selfattn"](attention_heads,
-                                                    output_size,
-                                                    attention_dropout_rate,
-                                                    query_bias, key_bias,
-                                                    value_bias, use_sdpa),
-                mlp_class(output_size, linear_units, dropout_rate, activation,
-                          mlp_bias), dropout_rate, normalize_before)
+            TransformerEncoderLayer(output_size,
+                                    WENET_ATTENTION_CLASSES["selfattn"](
+                                        attention_heads, output_size,
+                                        attention_dropout_rate, query_bias,
+                                        key_bias, value_bias, use_sdpa),
+                                    mlp_class(output_size, linear_units,
+                                              dropout_rate, activation,
+                                              mlp_bias),
+                                    dropout_rate,
+                                    normalize_before,
+                                    layer_norm_type=layer_norm_type)
             for _ in range(num_blocks)
         ])
 

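For orientation, a hedged end-to-end sketch of the new knob (the argument names and values other than layer_norm_type are assumptions about the existing TransformerEncoder signature, and the sizes are illustrative):

    from wenet.transformer.encoder import TransformerEncoder

    encoder = TransformerEncoder(
        input_size=80,               # e.g. 80-dim fbank features (illustrative)
        output_size=256,
        attention_heads=4,
        num_blocks=6,
        layer_norm_type='rms_norm',  # new in this commit; default stays 'layer_norm'
    )
    # Each TransformerEncoderLayer and the final after_norm now use RMSNorm.
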
24 changes: 17 additions & 7 deletions wenet/transformer/encoder_layer.py
@@ -20,6 +20,8 @@
 import torch
 from torch import nn
 
+from wenet.utils.class_utils import WENET_NORM_CLASSES
+
 
 class TransformerEncoderLayer(nn.Module):
     """Encoder layer module.
@@ -44,13 +46,15 @@ def __init__(
         feed_forward: torch.nn.Module,
         dropout_rate: float,
         normalize_before: bool = True,
+        layer_norm_type: str = 'layer_norm',
     ):
         """Construct an EncoderLayer object."""
         super().__init__()
         self.self_attn = self_attn
         self.feed_forward = feed_forward
-        self.norm1 = nn.LayerNorm(size, eps=1e-5)
-        self.norm2 = nn.LayerNorm(size, eps=1e-5)
+        assert layer_norm_type in ['layer_norm', 'rms_norm']
+        self.norm1 = WENET_NORM_CLASSES[layer_norm_type](size, eps=1e-5)
+        self.norm2 = WENET_NORM_CLASSES[layer_norm_type](size, eps=1e-5)
         self.dropout = nn.Dropout(dropout_rate)
         self.size = size
         self.normalize_before = normalize_before
@@ -135,23 +139,29 @@ def __init__(
         conv_module: Optional[nn.Module] = None,
         dropout_rate: float = 0.1,
         normalize_before: bool = True,
+        layer_norm_type: str = 'layer_norm',
     ):
         """Construct an EncoderLayer object."""
         super().__init__()
         self.self_attn = self_attn
         self.feed_forward = feed_forward
+        assert layer_norm_type in ['layer_norm', 'rms_norm']
         self.feed_forward_macaron = feed_forward_macaron
         self.conv_module = conv_module
-        self.norm_ff = nn.LayerNorm(size, eps=1e-5)  # for the FNN module
-        self.norm_mha = nn.LayerNorm(size, eps=1e-5)  # for the MHA module
+        self.norm_ff = WENET_NORM_CLASSES[layer_norm_type](
+            size, eps=1e-5)  # for the FNN module
+        self.norm_mha = WENET_NORM_CLASSES[layer_norm_type](
+            size, eps=1e-5)  # for the MHA module
         if feed_forward_macaron is not None:
-            self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-5)
+            self.norm_ff_macaron = WENET_NORM_CLASSES[layer_norm_type](
+                size, eps=1e-5)
             self.ff_scale = 0.5
         else:
             self.ff_scale = 1.0
         if self.conv_module is not None:
-            self.norm_conv = nn.LayerNorm(size, eps=1e-5)  # for the CNN module
-            self.norm_final = nn.LayerNorm(
+            self.norm_conv = WENET_NORM_CLASSES[layer_norm_type](
+                size, eps=1e-5)  # for the CNN module
+            self.norm_final = WENET_NORM_CLASSES[layer_norm_type](
                 size, eps=1e-5)  # for the final output of the block
         self.dropout = nn.Dropout(dropout_rate)
         self.size = size

22 changes: 22 additions & 0 deletions wenet/transformer/norm.py
@@ -0,0 +1,22 @@
+import torch
+
+
+class RMSNorm(torch.nn.Module):
+    """ https://arxiv.org/pdf/1910.07467.pdf
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        eps: float = 1e-6,
+    ):
+        super().__init__()
+        self.eps = eps
+        self.weight = torch.nn.Parameter(torch.ones(dim))
+
+    def _norm(self, x):
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+    def forward(self, x):
+        x = self._norm(x.float()).type_as(x)
+        return x * self.weight

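A small sanity-check sketch of the new module (shapes are illustrative): RMSNorm rescales each feature vector by its root mean square, with no mean subtraction and no bias term, then multiplies by the learned per-dimension weight (initialized to ones).

    import torch
    from wenet.transformer.norm import RMSNorm

    x = torch.randn(2, 10, 256)             # (batch, time, dim), illustrative
    norm = RMSNorm(256)                     # eps defaults to 1e-6
    y = norm(x)

    print(y.shape)                          # torch.Size([2, 10, 256])
    # With the weight still at its initial value of ones, each vector's RMS is ~1.
    print(y.pow(2).mean(-1).sqrt().mean())  # ~1.0
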
8 changes: 8 additions & 0 deletions wenet/utils/class_utils.py
@@ -2,7 +2,9 @@
 # -*- coding: utf-8 -*-
 # Copyright [2023-11-28] <[email protected], Xingchen Song>
 import torch
+from torch.nn import BatchNorm1d, LayerNorm
 from wenet.paraformer.embedding import ParaformerPositinoalEncoding
+from wenet.transformer.norm import RMSNorm
 from wenet.transformer.positionwise_feed_forward import (
     GatedVariantsMLP, MoEFFNLayer, PositionwiseFeedForward)
 
@@ -77,3 +79,9 @@
     'moe': MoEFFNLayer,
     'gated': GatedVariantsMLP
 }
+
+WENET_NORM_CLASSES = {
+    'layer_norm': LayerNorm,
+    'batch_norm': BatchNorm1d,
+    'rms_norm': RMSNorm
+}
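A short usage sketch of the registry, mirroring how the encoder/decoder changes above resolve their norm layers (the 256 and eps values are illustrative):

    from wenet.utils.class_utils import WENET_NORM_CLASSES

    layer_norm_type = 'rms_norm'  # or 'layer_norm'
    norm = WENET_NORM_CLASSES[layer_norm_type](256, eps=1e-5)
    # 'layer_norm' -> torch.nn.LayerNorm and 'rms_norm' -> RMSNorm both take
    # (dim, eps=...); 'batch_norm' (BatchNorm1d) is used by ConvolutionModule,
    # which passes only the channel count.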
