Commit 46e4f84: bigvgan

MaxMax2016 committed May 29, 2023
1 parent d4b76cb commit 46e4f84

Showing 12 changed files with 378 additions and 24 deletions.
3 changes: 1 addition & 2 deletions vits/models.py
@@ -194,6 +194,7 @@ def forward(self, ppg, pit, spec, spk, ppg_l, spec_l):
return audio, ids_slice, spec_mask, (z_f, z_r, z_p, m_p, logs_p, z_q, m_q, logs_q, logdet_f, logdet_r), spk_preds

def infer(self, ppg, pit, spk, ppg_l):
ppg = ppg + torch.randn_like(ppg) * 0.0001 # Perturbation
z_p, m_p, logs_p, ppg_mask, x = self.enc_p(
ppg, ppg_l, f0=f0_to_coarse(pit))
z, _ = self.flow(z_p, ppg_mask, g=spk, reverse=True)
@@ -241,10 +242,8 @@ def source2wav(self, source):
return self.dec.source2wav(source)

def inference(self, ppg, pit, spk, ppg_l, source):
ppg = ppg + torch.randn_like(ppg) * 0.0001 # Perturbation
z_p, m_p, logs_p, ppg_mask, x = self.enc_p(
ppg, ppg_l, f0=f0_to_coarse(pit))
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * 0.7
z, _ = self.flow(z_p, ppg_mask, g=spk, reverse=True)
o = self.dec.inference(spk, z * ppg_mask, source)
return o
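For reference, the z_p draw in `inference` above is the standard reparameterization trick with a fixed 0.7 temperature; a minimal standalone sketch (the function name is illustrative, not from the repo):

import torch

def sample_prior(m_p, logs_p, temperature=0.7):
    # Draw z_p from N(m_p, exp(logs_p)^2); the temperature scales the noise
    # term, trading sample diversity for stability.
    return m_p + torch.randn_like(m_p) * torch.exp(logs_p) * temperature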
1 change: 1 addition & 0 deletions vits_decoder/__init__.py
@@ -0,0 +1 @@
from .alias.act import SnakeAlias
6 changes: 6 additions & 0 deletions vits_decoder/alias/__init__.py
@@ -0,0 +1,6 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.

from .filter import *
from .resample import *
from .act import *
129 changes: 129 additions & 0 deletions vits_decoder/alias/act.py
@@ -0,0 +1,129 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import sin, pow
from torch.nn import Parameter
from .resample import UpSample1d, DownSample1d


class Activation1d(nn.Module):
def __init__(self,
activation,
up_ratio: int = 2,
down_ratio: int = 2,
up_kernel_size: int = 12,
down_kernel_size: int = 12):
super().__init__()
self.up_ratio = up_ratio
self.down_ratio = down_ratio
self.act = activation
self.upsample = UpSample1d(up_ratio, up_kernel_size)
self.downsample = DownSample1d(down_ratio, down_kernel_size)

# x: [B,C,T]
def forward(self, x):
x = self.upsample(x)
x = self.act(x)
x = self.downsample(x)

return x


class SnakeBeta(nn.Module):
'''
A modified Snake function which uses a separate parameter, beta, for the
magnitude of the periodic component.
Shape:
- Input: (B, C, T)
- Output: (B, C, T), same shape as the input
Parameters:
- alpha - trainable parameter that controls frequency
- beta - trainable parameter that controls magnitude
References:
- This activation function is a modified version of the Snake activation
proposed by Liu Ziyin, Tilman Hartwig, and Masahito Ueda:
https://arxiv.org/abs/2006.08195
Examples:
>>> a1 = SnakeBeta(256)
>>> x = torch.randn(2, 256, 100)  # (B, C, T)
>>> x = a1(x)
'''

def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
'''
Initialization.
INPUT:
- in_features: shape of the input
- alpha - trainable parameter that controls frequency
- beta - trainable parameter that controls magnitude
alpha is initialized to 1 by default, higher values = higher-frequency.
beta is initialized to 1 by default, higher values = higher-magnitude.
alpha and beta will be trained along with the rest of your model.
'''
super(SnakeBeta, self).__init__()
self.in_features = in_features
# initialize alpha
self.alpha_logscale = alpha_logscale
if self.alpha_logscale: # log scale alphas initialized to zeros
self.alpha = Parameter(torch.zeros(in_features) * alpha)
self.beta = Parameter(torch.zeros(in_features) * alpha)
else: # linear scale alphas initialized to ones
self.alpha = Parameter(torch.ones(in_features) * alpha)
self.beta = Parameter(torch.ones(in_features) * alpha)
self.alpha.requires_grad = alpha_trainable
self.beta.requires_grad = alpha_trainable
self.no_div_by_zero = 1e-9

def forward(self, x):
'''
Forward pass of the function.
Applies the function to the input elementwise.
SnakeBeta(x) = x + (1 / beta) * sin^2(alpha * x)
'''
alpha = self.alpha.unsqueeze(
0).unsqueeze(-1) # line up with x to [B, C, T]
beta = self.beta.unsqueeze(0).unsqueeze(-1)
if self.alpha_logscale:
alpha = torch.exp(alpha)
beta = torch.exp(beta)
x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
return x


class Mish(nn.Module):
"""
Mish activation function is proposed in "Mish: A Self
Regularized Non-Monotonic Neural Activation Function"
paper, https://arxiv.org/abs/1908.08681.
"""

def __init__(self):
super().__init__()

def forward(self, x):
return x * torch.tanh(F.softplus(x))


class SnakeAlias(nn.Module):
def __init__(self,
channels,
up_ratio: int = 2,
down_ratio: int = 2,
up_kernel_size: int = 12,
down_kernel_size: int = 12):
super().__init__()
self.up_ratio = up_ratio
self.down_ratio = down_ratio
self.act = SnakeBeta(channels, alpha_logscale=True)
self.upsample = UpSample1d(up_ratio, up_kernel_size)
self.downsample = DownSample1d(down_ratio, down_kernel_size)

# x: [B,C,T]
def forward(self, x):
x = self.upsample(x)
x = self.act(x)
x = self.downsample(x)

return x
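A quick usage sketch for SnakeAlias as defined above; shapes assume the default 2x up/down ratios, which cancel out so the time length is preserved:

import torch
from vits_decoder.alias.act import SnakeAlias

act = SnakeAlias(channels=64)  # upsample 2x -> SnakeBeta -> downsample 2x
x = torch.randn(1, 64, 256)    # [B, C, T]
y = act(x)
assert y.shape == x.shape      # the anti-aliased activation keeps [B, C, T]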
95 changes: 95 additions & 0 deletions vits_decoder/alias/filter.py
@@ -0,0 +1,95 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

if 'sinc' in dir(torch):
sinc = torch.sinc
else:
# This code is adapted from adefossez's julius.core.sinc under the MIT License
# https://adefossez.github.io/julius/julius/core.html
# LICENSE is in incl_licenses directory.
def sinc(x: torch.Tensor):
"""
Implementation of sinc, i.e. sin(pi * x) / (pi * x)
__Warning__: Unlike julius.sinc, the input here is multiplied by `pi`!
"""
return torch.where(x == 0,
torch.tensor(1., device=x.device, dtype=x.dtype),
torch.sin(math.pi * x) / math.pi / x)


# This code is adapted from adefossez's julius.lowpass.LowPassFilters under the MIT License
# https://adefossez.github.io/julius/julius/lowpass.html
# LICENSE is in incl_licenses directory.
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
even = (kernel_size % 2 == 0)
half_size = kernel_size // 2

# For the Kaiser window
delta_f = 4 * half_width
A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
if A > 50.:
beta = 0.1102 * (A - 8.7)
elif A >= 21.:
beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
else:
beta = 0.
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)

# ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
if even:
time = (torch.arange(-half_size, half_size) + 0.5)
else:
time = torch.arange(kernel_size) - half_size
if cutoff == 0:
filter_ = torch.zeros_like(time)
else:
filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
# Normalize filter to have sum = 1, otherwise we will have a small leakage
# of the constant component in the input signal.
filter_ /= filter_.sum()
filter = filter_.view(1, 1, kernel_size)

return filter


class LowPassFilter1d(nn.Module):
def __init__(self,
cutoff=0.5,
half_width=0.6,
stride: int = 1,
padding: bool = True,
padding_mode: str = 'replicate',
kernel_size: int = 12):
# kernel_size should be even number for stylegan3 setup,
# in this implementation, odd number is also possible.
super().__init__()
if cutoff < 0.:
raise ValueError("Minimum cutoff must not be negative.")
if cutoff > 0.5:
raise ValueError("A cutoff above 0.5 does not make sense.")
self.kernel_size = kernel_size
self.even = (kernel_size % 2 == 0)
self.pad_left = kernel_size // 2 - int(self.even)
self.pad_right = kernel_size // 2
self.stride = stride
self.padding = padding
self.padding_mode = padding_mode
filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
self.register_buffer("filter", filter)

#input [B, C, T]
def forward(self, x):
_, C, _ = x.shape

if self.padding:
x = F.pad(x, (self.pad_left, self.pad_right),
mode=self.padding_mode)
out = F.conv1d(x, self.filter.expand(C, -1, -1),
stride=self.stride, groups=C)

return out
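Two illustrative checks of the module above (argument values are examples, not repo defaults): the Kaiser-windowed sinc kernel is normalized to unit sum, and with stride 1 and padding enabled the filter preserves the time length:

import torch
from vits_decoder.alias.filter import LowPassFilter1d, kaiser_sinc_filter1d

k = kaiser_sinc_filter1d(cutoff=0.25, half_width=0.3, kernel_size=12)
print(k.shape)         # torch.Size([1, 1, 12])
print(k.sum().item())  # ~1.0 after the DC normalization

lpf = LowPassFilter1d(cutoff=0.25, half_width=0.3, kernel_size=12)
x = torch.randn(1, 4, 128)  # [B, C, T]
print(lpf(x).shape)         # torch.Size([1, 4, 128])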
49 changes: 49 additions & 0 deletions vits_decoder/alias/resample.py
@@ -0,0 +1,49 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.

import torch.nn as nn
from torch.nn import functional as F
from .filter import LowPassFilter1d
from .filter import kaiser_sinc_filter1d


class UpSample1d(nn.Module):
def __init__(self, ratio=2, kernel_size=None):
super().__init__()
self.ratio = ratio
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
self.stride = ratio
self.pad = self.kernel_size // ratio - 1
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio,
half_width=0.6 / ratio,
kernel_size=self.kernel_size)
self.register_buffer("filter", filter)

# x: [B, C, T]
def forward(self, x):
_, C, _ = x.shape

x = F.pad(x, (self.pad, self.pad), mode='replicate')
x = self.ratio * F.conv_transpose1d(
x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
x = x[..., self.pad_left:-self.pad_right]

return x


class DownSample1d(nn.Module):
def __init__(self, ratio=2, kernel_size=None):
super().__init__()
self.ratio = ratio
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio,
half_width=0.6 / ratio,
stride=ratio,
kernel_size=self.kernel_size)

def forward(self, x):
xx = self.lowpass(x)

return xx
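A round-trip sketch: UpSample1d doubles the time axis and DownSample1d halves it again, so the pair is shape-preserving (ratio and tensor sizes are illustrative):

import torch
from vits_decoder.alias.resample import UpSample1d, DownSample1d

up = UpSample1d(ratio=2)    # kernel_size defaults to 12 for ratio=2
down = DownSample1d(ratio=2)
x = torch.randn(1, 8, 100)  # [B, C, T]
u = up(x)                   # torch.Size([1, 8, 200])
y = down(u)                 # torch.Size([1, 8, 100])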
18 changes: 13 additions & 5 deletions vits_decoder/bigv.py
@@ -1,10 +1,9 @@
import torch
import torch.nn.functional as F
import torch.nn as nn

from torch import nn
from torch.nn import Conv1d
from torch.nn.utils import weight_norm, remove_weight_norm
from .alias.act import SnakeAlias


def init_weights(m, mean=0.0, std=0.01):
@@ -40,11 +39,20 @@ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
])
self.convs2.apply(init_weights)

# total number of conv layers
self.num_layers = len(self.convs1) + len(self.convs2)

# periodic nonlinearity with snakebeta function and anti-aliasing
self.activations = nn.ModuleList([
SnakeAlias(channels) for _ in range(self.num_layers)
])

def forward(self, x):
for c1, c2 in zip(self.convs1, self.convs2):
xt = F.leaky_relu(x, 0.1)
acts1, acts2 = self.activations[::2], self.activations[1::2]
for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
xt = a1(x)
xt = c1(xt)
xt = F.leaky_relu(xt, 0.1)
xt = a2(xt)
xt = c2(xt)
x = xt + x
return x
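The hunk above replaces the block's shared leaky_relu with one anti-aliased SnakeBeta activation per conv layer. A self-contained sketch of the resulting residual pattern (the class name and padding arithmetic are hypothetical, mirroring the diff rather than reproducing the repo's exact block):

import torch
import torch.nn as nn
from torch.nn import Conv1d
from torch.nn.utils import weight_norm
from vits_decoder.alias.act import SnakeAlias

class TinyAMPBlock(nn.Module):
    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super().__init__()
        self.convs1 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, dilation=d,
                               padding=d * (kernel_size - 1) // 2))
            for d in dilation
        ])
        self.convs2 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size,
                               padding=kernel_size // 2))
            for _ in dilation
        ])
        # one activation per conv layer, as in the diff
        self.activations = nn.ModuleList(
            [SnakeAlias(channels) for _ in range(2 * len(dilation))]
        )

    def forward(self, x):
        acts1, acts2 = self.activations[::2], self.activations[1::2]
        for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
            xt = c2(a2(c1(a1(x))))
            x = xt + x
        return x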
8 changes: 4 additions & 4 deletions vits_decoder/discriminator.py
@@ -2,7 +2,7 @@
import torch.nn as nn

from omegaconf import OmegaConf

from .msd import ScaleDiscriminator
from .mpd import MultiPeriodDiscriminator
from .mrd import MultiResolutionDiscriminator

@@ -12,13 +12,13 @@ def __init__(self, hp):
super(Discriminator, self).__init__()
self.MRD = MultiResolutionDiscriminator(hp)
self.MPD = MultiPeriodDiscriminator(hp)

self.MSD = ScaleDiscriminator()

def forward(self, x):
r = self.MRD(x)
p = self.MPD(x)

return r + p
s = self.MSD(x)
return r + p + s


if __name__ == '__main__':
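For orientation, an annotated copy of the diffed forward: `r + p + s` is Python list concatenation of the per-discriminator outputs, not an elementwise sum (assuming each sub-discriminator returns a list, as is typical in HiFi-GAN-style stacks):

def forward(self, x):
    r = self.MRD(x)   # multi-resolution (spectral) discriminator outputs
    p = self.MPD(x)   # multi-period discriminator outputs
    s = self.MSD(x)   # outputs of the newly added scale discriminator
    return r + p + s  # concatenated list of outputs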
