removed redundant files

jinzr 2023-12-01 00:16:00 +08:00
parent 0f051f5518
commit 9931694455
9 changed files with 0 additions and 3444 deletions


@@ -1,194 +0,0 @@
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/duration_predictor.py
# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Stochastic duration predictor modules in VITS.
This code is based on https://github.com/jaywalnut310/vits.
"""
import math
from typing import Optional
import torch
import torch.nn.functional as F
from flow import (
ConvFlow,
DilatedDepthSeparableConv,
ElementwiseAffineFlow,
FlipFlow,
LogFlow,
)
class StochasticDurationPredictor(torch.nn.Module):
"""Stochastic duration predictor module.
This is a module of stochastic duration predictor described in `Conditional
Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`_.
.. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
Text-to-Speech`: https://arxiv.org/abs/2106.06103
"""
def __init__(
self,
channels: int = 192,
kernel_size: int = 3,
dropout_rate: float = 0.5,
flows: int = 4,
dds_conv_layers: int = 3,
global_channels: int = -1,
):
"""Initialize StochasticDurationPredictor module.
Args:
channels (int): Number of channels.
kernel_size (int): Kernel size.
dropout_rate (float): Dropout rate.
flows (int): Number of flows.
dds_conv_layers (int): Number of conv layers in DDS conv.
global_channels (int): Number of global conditioning channels.
"""
super().__init__()
self.pre = torch.nn.Conv1d(channels, channels, 1)
self.dds = DilatedDepthSeparableConv(
channels,
kernel_size,
layers=dds_conv_layers,
dropout_rate=dropout_rate,
)
self.proj = torch.nn.Conv1d(channels, channels, 1)
self.log_flow = LogFlow()
self.flows = torch.nn.ModuleList()
self.flows += [ElementwiseAffineFlow(2)]
for i in range(flows):
self.flows += [
ConvFlow(
2,
channels,
kernel_size,
layers=dds_conv_layers,
)
]
self.flows += [FlipFlow()]
self.post_pre = torch.nn.Conv1d(1, channels, 1)
self.post_dds = DilatedDepthSeparableConv(
channels,
kernel_size,
layers=dds_conv_layers,
dropout_rate=dropout_rate,
)
self.post_proj = torch.nn.Conv1d(channels, channels, 1)
self.post_flows = torch.nn.ModuleList()
self.post_flows += [ElementwiseAffineFlow(2)]
for i in range(flows):
self.post_flows += [
ConvFlow(
2,
channels,
kernel_size,
layers=dds_conv_layers,
)
]
self.post_flows += [FlipFlow()]
if global_channels > 0:
self.global_conv = torch.nn.Conv1d(global_channels, channels, 1)
def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
w: Optional[torch.Tensor] = None,
g: Optional[torch.Tensor] = None,
inverse: bool = False,
noise_scale: float = 1.0,
) -> torch.Tensor:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, channels, T_text).
x_mask (Tensor): Mask tensor (B, 1, T_text).
w (Optional[Tensor]): Duration tensor (B, 1, T_text).
g (Optional[Tensor]): Global conditioning tensor (B, channels, 1).
inverse (bool): Whether to inverse the flow.
noise_scale (float): Noise scale value.
Returns:
Tensor: If not inverse, negative log-likelihood (NLL) tensor (B,).
If inverse, log-duration tensor (B, 1, T_text).
"""
x = x.detach() # stop gradient
x = self.pre(x)
if g is not None:
x = x + self.global_conv(g.detach()) # stop gradient
x = self.dds(x, x_mask)
x = self.proj(x) * x_mask
if not inverse:
assert w is not None, "w must be provided."
h_w = self.post_pre(w)
h_w = self.post_dds(h_w, x_mask)
h_w = self.post_proj(h_w) * x_mask
e_q = (
torch.randn(
w.size(0),
2,
w.size(2),
).to(device=x.device, dtype=x.dtype)
* x_mask
)
z_q = e_q
logdet_tot_q = 0.0
for flow in self.post_flows:
z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
logdet_tot_q += logdet_q
z_u, z1 = torch.split(z_q, [1, 1], 1)
u = torch.sigmoid(z_u) * x_mask
z0 = (w - u) * x_mask
logdet_tot_q += torch.sum(
(F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
)
logq = (
torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
- logdet_tot_q
)
logdet_tot = 0
z0, logdet = self.log_flow(z0, x_mask)
logdet_tot += logdet
z = torch.cat([z0, z1], 1)
for flow in self.flows:
z, logdet = flow(z, x_mask, g=x, inverse=inverse)
logdet_tot = logdet_tot + logdet
nll = (
torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
- logdet_tot
)
return nll + logq # (B,)
else:
flows = list(reversed(self.flows))
flows = flows[:-2] + [flows[-1]] # remove a useless vflow
z = (
torch.randn(
x.size(0),
2,
x.size(2),
).to(device=x.device, dtype=x.dtype)
* noise_scale
)
for flow in flows:
z = flow(z, x_mask, g=x, inverse=inverse)
z0, z1 = z.split(1, 1)
logw = z0
return logw
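For reference, a minimal usage sketch of the removed predictor (hypothetical toy shapes; it assumes this file and its flow dependencies were still importable from the recipe directory):
import torch
from duration_predictor import StochasticDurationPredictor

# Toy batch: B=2, channels=192, T_text=50.
predictor = StochasticDurationPredictor(channels=192)
x = torch.randn(2, 192, 50)      # text-encoder hidden states (detached internally)
x_mask = torch.ones(2, 1, 50)    # all positions valid
w = torch.rand(2, 1, 50) + 1.0   # positive per-token durations

nll = predictor(x, x_mask, w=w)  # training mode: per-utterance NLL, shape (B,)
log_dur = predictor(x, x_mask, inverse=True, noise_scale=0.8)  # inference: log-durations, shape (B, 1, T_text)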


@@ -1,312 +0,0 @@
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/flow.py
# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Basic Flow modules used in VITS.
This code is based on https://github.com/jaywalnut310/vits.
"""
import math
from typing import Optional, Tuple, Union
import torch
from transform import piecewise_rational_quadratic_transform
class FlipFlow(torch.nn.Module):
"""Flip flow module."""
def forward(
self, x: torch.Tensor, *args, inverse: bool = False, **kwargs
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, channels, T).
inverse (bool): Whether to inverse the flow.
Returns:
Tensor: Flipped tensor (B, channels, T).
Tensor: Log-determinant tensor for NLL (B,) if not inverse.
"""
x = torch.flip(x, [1])
if not inverse:
logdet = x.new_zeros(x.size(0))
return x, logdet
else:
return x
class LogFlow(torch.nn.Module):
"""Log flow module."""
def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
inverse: bool = False,
eps: float = 1e-5,
**kwargs
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, channels, T).
x_mask (Tensor): Mask tensor (B, 1, T).
inverse (bool): Whether to inverse the flow.
eps (float): Epsilon for log.
Returns:
Tensor: Output tensor (B, channels, T).
Tensor: Log-determinant tensor for NLL (B,) if not inverse.
"""
if not inverse:
y = torch.log(torch.clamp_min(x, eps)) * x_mask
logdet = torch.sum(-y, [1, 2])
return y, logdet
else:
x = torch.exp(x) * x_mask
return x
class ElementwiseAffineFlow(torch.nn.Module):
"""Elementwise affine flow module."""
def __init__(self, channels: int):
"""Initialize ElementwiseAffineFlow module.
Args:
channels (int): Number of channels.
"""
super().__init__()
self.channels = channels
self.register_parameter("m", torch.nn.Parameter(torch.zeros(channels, 1)))
self.register_parameter("logs", torch.nn.Parameter(torch.zeros(channels, 1)))
def forward(
self, x: torch.Tensor, x_mask: torch.Tensor, inverse: bool = False, **kwargs
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, channels, T).
x_mask (Tensor): Mask tensor (B, 1, T).
inverse (bool): Whether to inverse the flow.
Returns:
Tensor: Output tensor (B, channels, T).
Tensor: Log-determinant tensor for NLL (B,) if not inverse.
"""
if not inverse:
y = self.m + torch.exp(self.logs) * x
y = y * x_mask
logdet = torch.sum(self.logs * x_mask, [1, 2])
return y, logdet
else:
x = (x - self.m) * torch.exp(-self.logs) * x_mask
return x
class Transpose(torch.nn.Module):
"""Transpose module for torch.nn.Sequential()."""
def __init__(self, dim1: int, dim2: int):
"""Initialize Transpose module."""
super().__init__()
self.dim1 = dim1
self.dim2 = dim2
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Transpose."""
return x.transpose(self.dim1, self.dim2)
class DilatedDepthSeparableConv(torch.nn.Module):
"""Dilated depth-separable conv module."""
def __init__(
self,
channels: int,
kernel_size: int,
layers: int,
dropout_rate: float = 0.0,
eps: float = 1e-5,
):
"""Initialize DilatedDepthSeparableConv module.
Args:
channels (int): Number of channels.
kernel_size (int): Kernel size.
layers (int): Number of layers.
dropout_rate (float): Dropout rate.
eps (float): Epsilon for layer norm.
"""
super().__init__()
self.convs = torch.nn.ModuleList()
for i in range(layers):
dilation = kernel_size**i
padding = (kernel_size * dilation - dilation) // 2
self.convs += [
torch.nn.Sequential(
torch.nn.Conv1d(
channels,
channels,
kernel_size,
groups=channels,
dilation=dilation,
padding=padding,
),
Transpose(1, 2),
torch.nn.LayerNorm(
channels,
eps=eps,
elementwise_affine=True,
),
Transpose(1, 2),
torch.nn.GELU(),
torch.nn.Conv1d(
channels,
channels,
1,
),
Transpose(1, 2),
torch.nn.LayerNorm(
channels,
eps=eps,
elementwise_affine=True,
),
Transpose(1, 2),
torch.nn.GELU(),
torch.nn.Dropout(dropout_rate),
)
]
def forward(
self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, in_channels, T).
x_mask (Tensor): Mask tensor (B, 1, T).
g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
Returns:
Tensor: Output tensor (B, channels, T).
"""
if g is not None:
x = x + g
for f in self.convs:
y = f(x * x_mask)
x = x + y
return x * x_mask
class ConvFlow(torch.nn.Module):
"""Convolutional flow module."""
def __init__(
self,
in_channels: int,
hidden_channels: int,
kernel_size: int,
layers: int,
bins: int = 10,
tail_bound: float = 5.0,
):
"""Initialize ConvFlow module.
Args:
in_channels (int): Number of input channels.
hidden_channels (int): Number of hidden channels.
kernel_size (int): Kernel size.
layers (int): Number of layers.
bins (int): Number of bins.
tail_bound (float): Tail bound value.
"""
super().__init__()
self.half_channels = in_channels // 2
self.hidden_channels = hidden_channels
self.bins = bins
self.tail_bound = tail_bound
self.input_conv = torch.nn.Conv1d(
self.half_channels,
hidden_channels,
1,
)
self.dds_conv = DilatedDepthSeparableConv(
hidden_channels,
kernel_size,
layers,
dropout_rate=0.0,
)
self.proj = torch.nn.Conv1d(
hidden_channels,
self.half_channels * (bins * 3 - 1),
1,
)
self.proj.weight.data.zero_()
self.proj.bias.data.zero_()
def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
g: Optional[torch.Tensor] = None,
inverse: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, channels, T).
x_mask (Tensor): Mask tensor (B, 1, T).
g (Optional[Tensor]): Global conditioning tensor (B, channels, 1).
inverse (bool): Whether to inverse the flow.
Returns:
Tensor: Output tensor (B, channels, T).
Tensor: Log-determinant tensor for NLL (B,) if not inverse.
"""
xa, xb = x.split(x.size(1) // 2, 1)
h = self.input_conv(xa)
h = self.dds_conv(h, x_mask, g=g)
h = self.proj(h) * x_mask # (B, half_channels * (bins * 3 - 1), T)
b, c, t = xa.shape
# (B, half_channels, bins * 3 - 1, T) -> (B, half_channels, T, bins * 3 - 1)
h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2)
# TODO(kan-bayashi): Understand this calculation
denom = math.sqrt(self.hidden_channels)
unnorm_widths = h[..., : self.bins] / denom
unnorm_heights = h[..., self.bins : 2 * self.bins] / denom
unnorm_derivatives = h[..., 2 * self.bins :]
xb, logdet_abs = piecewise_rational_quadratic_transform(
xb,
unnorm_widths,
unnorm_heights,
unnorm_derivatives,
inverse=inverse,
tails="linear",
tail_bound=self.tail_bound,
)
x = torch.cat([xa, xb], 1) * x_mask
logdet = torch.sum(logdet_abs * x_mask, [1, 2])
if not inverse:
return x, logdet
else:
return x
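As a quick sanity check of the flow interface, a hypothetical round trip through ConvFlow (in_channels=2 mirrors its use in the duration predictor):
import torch
from flow import ConvFlow

flow = ConvFlow(in_channels=2, hidden_channels=192, kernel_size=3, layers=3)
z = torch.randn(4, 2, 60)
z_mask = torch.ones(4, 1, 60)

y, logdet = flow(z, z_mask)             # forward: transformed tensor plus log-determinant of shape (B,)
z_back = flow(y, z_mask, inverse=True)  # inverse direction returns only the tensor
# z_back should match z up to numerical precision, since the spline transform is invertible.
print(torch.allclose(z, z_back, atol=1e-4))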


@@ -1,531 +0,0 @@
# based on https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/generator.py
# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Generator module in VITS.
This code is based on https://github.com/jaywalnut310/vits.
"""
import math
from typing import List, Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from icefall.utils import make_pad_mask
from duration_predictor import StochasticDurationPredictor
from hifigan import HiFiGANGenerator
from posterior_encoder import PosteriorEncoder
from residual_coupling import ResidualAffineCouplingBlock
from text_encoder import TextEncoder
from utils import get_random_segments
class VITSGenerator(torch.nn.Module):
"""Generator module in VITS, `Conditional Variational Autoencoder
with Adversarial Learning for End-to-End Text-to-Speech`.
"""
def __init__(
self,
vocabs: int,
aux_channels: int = 513,
hidden_channels: int = 192,
spks: Optional[int] = None,
langs: Optional[int] = None,
spk_embed_dim: Optional[int] = None,
global_channels: int = -1,
segment_size: int = 32,
text_encoder_attention_heads: int = 2,
text_encoder_ffn_expand: int = 4,
text_encoder_cnn_module_kernel: int = 5,
text_encoder_blocks: int = 6,
text_encoder_dropout_rate: float = 0.1,
decoder_kernel_size: int = 7,
decoder_channels: int = 512,
decoder_upsample_scales: List[int] = [8, 8, 2, 2],
decoder_upsample_kernel_sizes: List[int] = [16, 16, 4, 4],
decoder_resblock_kernel_sizes: List[int] = [3, 7, 11],
decoder_resblock_dilations: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
use_weight_norm_in_decoder: bool = True,
posterior_encoder_kernel_size: int = 5,
posterior_encoder_layers: int = 16,
posterior_encoder_stacks: int = 1,
posterior_encoder_base_dilation: int = 1,
posterior_encoder_dropout_rate: float = 0.0,
use_weight_norm_in_posterior_encoder: bool = True,
flow_flows: int = 4,
flow_kernel_size: int = 5,
flow_base_dilation: int = 1,
flow_layers: int = 4,
flow_dropout_rate: float = 0.0,
use_weight_norm_in_flow: bool = True,
use_only_mean_in_flow: bool = True,
stochastic_duration_predictor_kernel_size: int = 3,
stochastic_duration_predictor_dropout_rate: float = 0.5,
stochastic_duration_predictor_flows: int = 4,
stochastic_duration_predictor_dds_conv_layers: int = 3,
):
"""Initialize VITS generator module.
Args:
vocabs (int): Input vocabulary size.
aux_channels (int): Number of acoustic feature channels.
hidden_channels (int): Number of hidden channels.
spks (Optional[int]): Number of speakers. If set to > 1, assume that the
sids will be provided as the input and use sid embedding layer.
langs (Optional[int]): Number of languages. If set to > 1, assume that the
lids will be provided as the input and use lid embedding layer.
spk_embed_dim (Optional[int]): Speaker embedding dimension. If set to > 0,
assume that spembs will be provided as the input.
global_channels (int): Number of global conditioning channels.
segment_size (int): Segment size for decoder.
text_encoder_attention_heads (int): Number of heads in conformer block
of text encoder.
text_encoder_ffn_expand (int): Expansion ratio of FFN in conformer block
of text encoder.
text_encoder_cnn_module_kernel (int): Convolution kernel size in text encoder.
text_encoder_blocks (int): Number of conformer blocks in text encoder.
text_encoder_dropout_rate (float): Dropout rate in conformer block of
text encoder.
decoder_kernel_size (int): Decoder kernel size.
decoder_channels (int): Number of decoder initial channels.
decoder_upsample_scales (List[int]): List of upsampling scales in decoder.
decoder_upsample_kernel_sizes (List[int]): List of kernel size for
upsampling layers in decoder.
decoder_resblock_kernel_sizes (List[int]): List of kernel size for resblocks
in decoder.
decoder_resblock_dilations (List[List[int]]): List of list of dilations for
resblocks in decoder.
use_weight_norm_in_decoder (bool): Whether to apply weight normalization in
decoder.
posterior_encoder_kernel_size (int): Posterior encoder kernel size.
posterior_encoder_layers (int): Number of layers of posterior encoder.
posterior_encoder_stacks (int): Number of stacks of posterior encoder.
posterior_encoder_base_dilation (int): Base dilation of posterior encoder.
posterior_encoder_dropout_rate (float): Dropout rate for posterior encoder.
use_weight_norm_in_posterior_encoder (bool): Whether to apply weight
normalization in posterior encoder.
flow_flows (int): Number of flows in flow.
flow_kernel_size (int): Kernel size in flow.
flow_base_dilation (int): Base dilation in flow.
flow_layers (int): Number of layers in flow.
flow_dropout_rate (float): Dropout rate in flow.
use_weight_norm_in_flow (bool): Whether to apply weight normalization in
flow.
use_only_mean_in_flow (bool): Whether to use only mean in flow.
stochastic_duration_predictor_kernel_size (int): Kernel size in stochastic
duration predictor.
stochastic_duration_predictor_dropout_rate (float): Dropout rate in
stochastic duration predictor.
stochastic_duration_predictor_flows (int): Number of flows in stochastic
duration predictor.
stochastic_duration_predictor_dds_conv_layers (int): Number of DDS conv
layers in stochastic duration predictor.
"""
super().__init__()
self.segment_size = segment_size
self.text_encoder = TextEncoder(
vocabs=vocabs,
d_model=hidden_channels,
num_heads=text_encoder_attention_heads,
dim_feedforward=hidden_channels * text_encoder_ffn_expand,
cnn_module_kernel=text_encoder_cnn_module_kernel,
num_layers=text_encoder_blocks,
dropout=text_encoder_dropout_rate,
)
self.decoder = HiFiGANGenerator(
in_channels=hidden_channels,
out_channels=1,
channels=decoder_channels,
global_channels=global_channels,
kernel_size=decoder_kernel_size,
upsample_scales=decoder_upsample_scales,
upsample_kernel_sizes=decoder_upsample_kernel_sizes,
resblock_kernel_sizes=decoder_resblock_kernel_sizes,
resblock_dilations=decoder_resblock_dilations,
use_weight_norm=use_weight_norm_in_decoder,
)
self.posterior_encoder = PosteriorEncoder(
in_channels=aux_channels,
out_channels=hidden_channels,
hidden_channels=hidden_channels,
kernel_size=posterior_encoder_kernel_size,
layers=posterior_encoder_layers,
stacks=posterior_encoder_stacks,
base_dilation=posterior_encoder_base_dilation,
global_channels=global_channels,
dropout_rate=posterior_encoder_dropout_rate,
use_weight_norm=use_weight_norm_in_posterior_encoder,
)
self.flow = ResidualAffineCouplingBlock(
in_channels=hidden_channels,
hidden_channels=hidden_channels,
flows=flow_flows,
kernel_size=flow_kernel_size,
base_dilation=flow_base_dilation,
layers=flow_layers,
global_channels=global_channels,
dropout_rate=flow_dropout_rate,
use_weight_norm=use_weight_norm_in_flow,
use_only_mean=use_only_mean_in_flow,
)
# TODO(kan-bayashi): Add deterministic version as an option
self.duration_predictor = StochasticDurationPredictor(
channels=hidden_channels,
kernel_size=stochastic_duration_predictor_kernel_size,
dropout_rate=stochastic_duration_predictor_dropout_rate,
flows=stochastic_duration_predictor_flows,
dds_conv_layers=stochastic_duration_predictor_dds_conv_layers,
global_channels=global_channels,
)
self.upsample_factor = int(np.prod(decoder_upsample_scales))
self.spks = None
if spks is not None and spks > 1:
assert global_channels > 0
self.spks = spks
self.global_emb = torch.nn.Embedding(spks, global_channels)
self.spk_embed_dim = None
if spk_embed_dim is not None and spk_embed_dim > 0:
assert global_channels > 0
self.spk_embed_dim = spk_embed_dim
self.spemb_proj = torch.nn.Linear(spk_embed_dim, global_channels)
self.langs = None
if langs is not None and langs > 1:
assert global_channels > 0
self.langs = langs
self.lang_emb = torch.nn.Embedding(langs, global_channels)
# delayed import
from monotonic_align import maximum_path
self.maximum_path = maximum_path
def forward(
self,
text: torch.Tensor,
text_lengths: torch.Tensor,
feats: torch.Tensor,
feats_lengths: torch.Tensor,
sids: Optional[torch.Tensor] = None,
spembs: Optional[torch.Tensor] = None,
lids: Optional[torch.Tensor] = None,
) -> Tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
Tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
],
]:
"""Calculate forward propagation.
Args:
text (Tensor): Text index tensor (B, T_text).
text_lengths (Tensor): Text length tensor (B,).
feats (Tensor): Feature tensor (B, aux_channels, T_feats).
feats_lengths (Tensor): Feature length tensor (B,).
sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
Returns:
Tensor: Waveform tensor (B, 1, segment_size * upsample_factor).
Tensor: Duration negative log-likelihood (NLL) tensor (B,).
Tensor: Monotonic attention weight tensor (B, 1, T_feats, T_text).
Tensor: Segments start index tensor (B,).
Tensor: Text mask tensor (B, 1, T_text).
Tensor: Feature mask tensor (B, 1, T_feats).
tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
- Tensor: Posterior encoder hidden representation (B, H, T_feats).
- Tensor: Flow hidden representation (B, H, T_feats).
- Tensor: Expanded text encoder projected mean (B, H, T_feats).
- Tensor: Expanded text encoder projected scale (B, H, T_feats).
- Tensor: Posterior encoder projected mean (B, H, T_feats).
- Tensor: Posterior encoder projected scale (B, H, T_feats).
"""
# forward text encoder
x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths)
# calculate global conditioning
g = None
if self.spks is not None:
# speaker one-hot vector embedding: (B, global_channels, 1)
g = self.global_emb(sids.view(-1)).unsqueeze(-1)
if self.spk_embed_dim is not None:
# pretrained speaker embedding, e.g., X-vector (B, global_channels, 1)
g_ = self.spemb_proj(F.normalize(spembs)).unsqueeze(-1)
if g is None:
g = g_
else:
g = g + g_
if self.langs is not None:
# language one-hot vector embedding: (B, global_channels, 1)
g_ = self.lang_emb(lids.view(-1)).unsqueeze(-1)
if g is None:
g = g_
else:
g = g + g_
# forward posterior encoder
z, m_q, logs_q, y_mask = self.posterior_encoder(feats, feats_lengths, g=g)
# forward flow
z_p = self.flow(z, y_mask, g=g) # (B, H, T_feats)
# monotonic alignment search
with torch.no_grad():
# negative cross-entropy
s_p_sq_r = torch.exp(-2 * logs_p) # (B, H, T_text)
# (B, 1, T_text)
neg_x_ent_1 = torch.sum(
-0.5 * math.log(2 * math.pi) - logs_p,
[1],
keepdim=True,
)
# (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
neg_x_ent_2 = torch.matmul(
-0.5 * (z_p**2).transpose(1, 2),
s_p_sq_r,
)
# (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
neg_x_ent_3 = torch.matmul(
z_p.transpose(1, 2),
(m_p * s_p_sq_r),
)
# (B, 1, T_text)
neg_x_ent_4 = torch.sum(
-0.5 * (m_p**2) * s_p_sq_r,
[1],
keepdim=True,
)
# (B, T_feats, T_text)
neg_x_ent = neg_x_ent_1 + neg_x_ent_2 + neg_x_ent_3 + neg_x_ent_4
# (B, 1, T_feats, T_text)
attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
# monotonic attention weight: (B, 1, T_feats, T_text)
attn = (
self.maximum_path(
neg_x_ent,
attn_mask.squeeze(1),
)
.unsqueeze(1)
.detach()
)
# forward duration predictor
w = attn.sum(2) # (B, 1, T_text)
dur_nll = self.duration_predictor(x, x_mask, w=w, g=g)
dur_nll = dur_nll / torch.sum(x_mask)
# expand the length to match with the feature sequence
# (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
# (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)
# get random segments
z_segments, z_start_idxs = get_random_segments(
z,
feats_lengths,
self.segment_size,
)
# forward decoder with random segments
wav = self.decoder(z_segments, g=g)
return (
wav,
dur_nll,
attn,
z_start_idxs,
x_mask,
y_mask,
(z, z_p, m_p, logs_p, m_q, logs_q),
)
def inference(
self,
text: torch.Tensor,
text_lengths: torch.Tensor,
feats: Optional[torch.Tensor] = None,
feats_lengths: Optional[torch.Tensor] = None,
sids: Optional[torch.Tensor] = None,
spembs: Optional[torch.Tensor] = None,
lids: Optional[torch.Tensor] = None,
dur: Optional[torch.Tensor] = None,
noise_scale: float = 0.667,
noise_scale_dur: float = 0.8,
alpha: float = 1.0,
max_len: Optional[int] = None,
use_teacher_forcing: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Run inference.
Args:
text (Tensor): Input text index tensor (B, T_text,).
text_lengths (Tensor): Text length tensor (B,).
feats (Tensor): Feature tensor (B, aux_channels, T_feats,).
feats_lengths (Tensor): Feature length tensor (B,).
sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
dur (Optional[Tensor]): Ground-truth duration (B, T_text,). If provided,
skip the prediction of durations (i.e., teacher forcing).
noise_scale (float): Noise scale parameter for flow.
noise_scale_dur (float): Noise scale parameter for duration predictor.
alpha (float): Alpha parameter to control the speed of generated speech.
max_len (Optional[int]): Maximum length of acoustic feature sequence.
use_teacher_forcing (bool): Whether to use teacher forcing.
Returns:
Tensor: Generated waveform tensor (B, T_wav).
Tensor: Monotonic attention weight tensor (B, T_feats, T_text).
Tensor: Duration tensor (B, T_text).
"""
# encoder
x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths)
x_mask = x_mask.to(x.dtype)
g = None
if self.spks is not None:
# (B, global_channels, 1)
g = self.global_emb(sids.view(-1)).unsqueeze(-1)
if self.spk_embed_dim is not None:
# (B, global_channels, 1)
g_ = self.spemb_proj(F.normalize(spembs.unsqueeze(0))).unsqueeze(-1)
if g is None:
g = g_
else:
g = g + g_
if self.langs is not None:
# (B, global_channels, 1)
g_ = self.lang_emb(lids.view(-1)).unsqueeze(-1)
if g is None:
g = g_
else:
g = g + g_
if use_teacher_forcing:
# forward posterior encoder
z, m_q, logs_q, y_mask = self.posterior_encoder(feats, feats_lengths, g=g)
# forward flow
z_p = self.flow(z, y_mask, g=g) # (B, H, T_feats)
# monotonic alignment search
s_p_sq_r = torch.exp(-2 * logs_p) # (B, H, T_text)
# (B, 1, T_text)
neg_x_ent_1 = torch.sum(
-0.5 * math.log(2 * math.pi) - logs_p,
[1],
keepdim=True,
)
# (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
neg_x_ent_2 = torch.matmul(
-0.5 * (z_p**2).transpose(1, 2),
s_p_sq_r,
)
# (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
neg_x_ent_3 = torch.matmul(
z_p.transpose(1, 2),
(m_p * s_p_sq_r),
)
# (B, 1, T_text)
neg_x_ent_4 = torch.sum(
-0.5 * (m_p**2) * s_p_sq_r,
[1],
keepdim=True,
)
# (B, T_feats, T_text)
neg_x_ent = neg_x_ent_1 + neg_x_ent_2 + neg_x_ent_3 + neg_x_ent_4
# (B, 1, T_feats, T_text)
attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
# monotonic attention weight: (B, 1, T_feats, T_text)
attn = self.maximum_path(
neg_x_ent,
attn_mask.squeeze(1),
).unsqueeze(1)
dur = attn.sum(2) # (B, 1, T_text)
# forward decoder with the full-length latent
wav = self.decoder(z * y_mask, g=g)
else:
# duration
if dur is None:
logw = self.duration_predictor(
x,
x_mask,
g=g,
inverse=True,
noise_scale=noise_scale_dur,
)
w = torch.exp(logw) * x_mask * alpha
dur = torch.ceil(w)
y_lengths = torch.clamp_min(torch.sum(dur, [1, 2]), 1).long()
y_mask = (~make_pad_mask(y_lengths)).unsqueeze(1).to(text.device)
y_mask = y_mask.to(x.dtype)
attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
attn = self._generate_path(dur, attn_mask)
# expand the length to match with the feature sequence
# (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
m_p = torch.matmul(
attn.squeeze(1),
m_p.transpose(1, 2),
).transpose(1, 2)
# (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
logs_p = torch.matmul(
attn.squeeze(1),
logs_p.transpose(1, 2),
).transpose(1, 2)
# decoder
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
z = self.flow(z_p, y_mask, g=g, inverse=True)
wav = self.decoder((z * y_mask)[:, :, :max_len], g=g)
return wav.squeeze(1), attn.squeeze(1), dur.squeeze(1)
def _generate_path(self, dur: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
"""Generate path a.k.a. monotonic attention.
Args:
dur (Tensor): Duration tensor (B, 1, T_text).
mask (Tensor): Attention mask tensor (B, 1, T_feats, T_text).
Returns:
Tensor: Path tensor (B, 1, T_feats, T_text).
"""
b, _, t_y, t_x = mask.shape
cum_dur = torch.cumsum(dur, -1)
cum_dur_flat = cum_dur.view(b * t_x)
path = torch.arange(t_y, dtype=dur.dtype, device=dur.device)
path = path.unsqueeze(0) < cum_dur_flat.unsqueeze(1)
# path = path.view(b, t_x, t_y).to(dtype=mask.dtype)
path = path.view(b, t_x, t_y).to(dtype=torch.float)
# path will be like (t_x = 3, t_y = 5):
# [[[1., 1., 0., 0., 0.], [[[1., 1., 0., 0., 0.],
# [1., 1., 1., 1., 0.], --> [0., 0., 1., 1., 0.],
# [1., 1., 1., 1., 1.]]] [0., 0., 0., 0., 1.]]]
path = path - F.pad(path, [0, 0, 1, 0, 0, 0])[:, :-1]
# path = path.to(dtype=mask.dtype)
return path.unsqueeze(1).transpose(2, 3) * mask
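A rough sketch of how the removed generator was driven (hypothetical vocabulary size and shapes; it assumes the sibling modules above and the compiled monotonic_align extension are available):
import torch
from generator import VITSGenerator

model = VITSGenerator(vocabs=100, aux_channels=80)
text = torch.randint(1, 100, (2, 30))    # (B, T_text) token ids
text_lengths = torch.tensor([30, 25])
feats = torch.randn(2, 80, 200)          # (B, aux_channels, T_feats)
feats_lengths = torch.tensor([200, 180])

# Training-time forward: random waveform segments plus everything needed for the VITS losses.
wav, dur_nll, attn, start_idxs, x_mask, y_mask, latents = model(text, text_lengths, feats, feats_lengths)
# wav: (B, 1, segment_size * upsample_factor); latents = (z, z_p, m_p, logs_p, m_q, logs_q) feed the KL term.

# Text-only inference.
wav_gen, attn_gen, dur = model.inference(text[:1], text_lengths[:1])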


@@ -1,933 +0,0 @@
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/hifigan.py
# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""HiFi-GAN Modules.
This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN.
"""
import copy
import logging
from typing import Any, Dict, List, Optional
import numpy as np
import torch
import torch.nn.functional as F
class HiFiGANGenerator(torch.nn.Module):
"""HiFiGAN generator module."""
def __init__(
self,
in_channels: int = 80,
out_channels: int = 1,
channels: int = 512,
global_channels: int = -1,
kernel_size: int = 7,
upsample_scales: List[int] = [8, 8, 2, 2],
upsample_kernel_sizes: List[int] = [16, 16, 4, 4],
resblock_kernel_sizes: List[int] = [3, 7, 11],
resblock_dilations: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
use_additional_convs: bool = True,
bias: bool = True,
nonlinear_activation: str = "LeakyReLU",
nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1},
use_weight_norm: bool = True,
):
"""Initialize HiFiGANGenerator module.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
channels (int): Number of hidden representation channels.
global_channels (int): Number of global conditioning channels.
kernel_size (int): Kernel size of initial and final conv layer.
upsample_scales (List[int]): List of upsampling scales.
upsample_kernel_sizes (List[int]): List of kernel sizes for upsample layers.
resblock_kernel_sizes (List[int]): List of kernel sizes for residual blocks.
resblock_dilations (List[List[int]]): List of list of dilations for residual
blocks.
use_additional_convs (bool): Whether to use additional conv layers in
residual blocks.
bias (bool): Whether to add bias parameter in convolution layers.
nonlinear_activation (str): Activation function module name.
nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation
function.
use_weight_norm (bool): Whether to use weight norm. If set to true, it will
be applied to all of the conv layers.
"""
super().__init__()
# check hyperparameters are valid
assert kernel_size % 2 == 1, "Kernel size must be odd number."
assert len(upsample_scales) == len(upsample_kernel_sizes)
assert len(resblock_dilations) == len(resblock_kernel_sizes)
# define modules
self.upsample_factor = int(np.prod(upsample_scales) * out_channels)
self.num_upsamples = len(upsample_kernel_sizes)
self.num_blocks = len(resblock_kernel_sizes)
self.input_conv = torch.nn.Conv1d(
in_channels,
channels,
kernel_size,
1,
padding=(kernel_size - 1) // 2,
)
self.upsamples = torch.nn.ModuleList()
self.blocks = torch.nn.ModuleList()
for i in range(len(upsample_kernel_sizes)):
assert upsample_kernel_sizes[i] == 2 * upsample_scales[i]
self.upsamples += [
torch.nn.Sequential(
getattr(torch.nn, nonlinear_activation)(
**nonlinear_activation_params
),
torch.nn.ConvTranspose1d(
channels // (2**i),
channels // (2 ** (i + 1)),
upsample_kernel_sizes[i],
upsample_scales[i],
padding=upsample_scales[i] // 2 + upsample_scales[i] % 2,
output_padding=upsample_scales[i] % 2,
),
)
]
for j in range(len(resblock_kernel_sizes)):
self.blocks += [
ResidualBlock(
kernel_size=resblock_kernel_sizes[j],
channels=channels // (2 ** (i + 1)),
dilations=resblock_dilations[j],
bias=bias,
use_additional_convs=use_additional_convs,
nonlinear_activation=nonlinear_activation,
nonlinear_activation_params=nonlinear_activation_params,
)
]
self.output_conv = torch.nn.Sequential(
# NOTE(kan-bayashi): follow official implementation but why
# using different slope parameter here? (0.1 vs. 0.01)
torch.nn.LeakyReLU(),
torch.nn.Conv1d(
channels // (2 ** (i + 1)),
out_channels,
kernel_size,
1,
padding=(kernel_size - 1) // 2,
),
torch.nn.Tanh(),
)
if global_channels > 0:
self.global_conv = torch.nn.Conv1d(global_channels, channels, 1)
# apply weight norm
if use_weight_norm:
self.apply_weight_norm()
# reset parameters
self.reset_parameters()
def forward(
self, c: torch.Tensor, g: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""Calculate forward propagation.
Args:
c (Tensor): Input tensor (B, in_channels, T).
g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
Returns:
Tensor: Output tensor (B, out_channels, T).
"""
c = self.input_conv(c)
if g is not None:
c = c + self.global_conv(g)
for i in range(self.num_upsamples):
c = self.upsamples[i](c)
cs = 0.0 # initialize
for j in range(self.num_blocks):
cs += self.blocks[i * self.num_blocks + j](c)
c = cs / self.num_blocks
c = self.output_conv(c)
return c
def reset_parameters(self):
"""Reset parameters.
This initialization follows the official implementation manner.
https://github.com/jik876/hifi-gan/blob/master/models.py
"""
def _reset_parameters(m: torch.nn.Module):
if isinstance(m, (torch.nn.Conv1d, torch.nn.ConvTranspose1d)):
m.weight.data.normal_(0.0, 0.01)
logging.debug(f"Reset parameters in {m}.")
self.apply(_reset_parameters)
def remove_weight_norm(self):
"""Remove weight normalization module from all of the layers."""
def _remove_weight_norm(m: torch.nn.Module):
try:
logging.debug(f"Weight norm is removed from {m}.")
torch.nn.utils.remove_weight_norm(m)
except ValueError: # this module didn't have weight norm
return
self.apply(_remove_weight_norm)
def apply_weight_norm(self):
"""Apply weight normalization module from all of the layers."""
def _apply_weight_norm(m: torch.nn.Module):
if isinstance(m, torch.nn.Conv1d) or isinstance(
m, torch.nn.ConvTranspose1d
):
torch.nn.utils.weight_norm(m)
logging.debug(f"Weight norm is applied to {m}.")
self.apply(_apply_weight_norm)
def inference(
self, c: torch.Tensor, g: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""Perform inference.
Args:
c (torch.Tensor): Input tensor (T, in_channels).
g (Optional[Tensor]): Global conditioning tensor (global_channels, 1).
Returns:
Tensor: Output tensor (T * upsample_factor, out_channels).
"""
if g is not None:
g = g.unsqueeze(0)
c = self.forward(c.transpose(1, 0).unsqueeze(0), g=g)
return c.squeeze(0).transpose(1, 0)
class ResidualBlock(torch.nn.Module):
"""Residual block module in HiFiGAN."""
def __init__(
self,
kernel_size: int = 3,
channels: int = 512,
dilations: List[int] = [1, 3, 5],
bias: bool = True,
use_additional_convs: bool = True,
nonlinear_activation: str = "LeakyReLU",
nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1},
):
"""Initialize ResidualBlock module.
Args:
kernel_size (int): Kernel size of dilation convolution layer.
channels (int): Number of channels for convolution layer.
dilations (List[int]): List of dilation factors.
use_additional_convs (bool): Whether to use additional convolution layers.
bias (bool): Whether to add bias parameter in convolution layers.
nonlinear_activation (str): Activation function module name.
nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation
function.
"""
super().__init__()
self.use_additional_convs = use_additional_convs
self.convs1 = torch.nn.ModuleList()
if use_additional_convs:
self.convs2 = torch.nn.ModuleList()
assert kernel_size % 2 == 1, "Kernel size must be odd number."
for dilation in dilations:
self.convs1 += [
torch.nn.Sequential(
getattr(torch.nn, nonlinear_activation)(
**nonlinear_activation_params
),
torch.nn.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation,
bias=bias,
padding=(kernel_size - 1) // 2 * dilation,
),
)
]
if use_additional_convs:
self.convs2 += [
torch.nn.Sequential(
getattr(torch.nn, nonlinear_activation)(
**nonlinear_activation_params
),
torch.nn.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
bias=bias,
padding=(kernel_size - 1) // 2,
),
)
]
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, channels, T).
Returns:
Tensor: Output tensor (B, channels, T).
"""
for idx in range(len(self.convs1)):
xt = self.convs1[idx](x)
if self.use_additional_convs:
xt = self.convs2[idx](xt)
x = xt + x
return x
class HiFiGANPeriodDiscriminator(torch.nn.Module):
"""HiFiGAN period discriminator module."""
def __init__(
self,
in_channels: int = 1,
out_channels: int = 1,
period: int = 3,
kernel_sizes: List[int] = [5, 3],
channels: int = 32,
downsample_scales: List[int] = [3, 3, 3, 3, 1],
max_downsample_channels: int = 1024,
bias: bool = True,
nonlinear_activation: str = "LeakyReLU",
nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1},
use_weight_norm: bool = True,
use_spectral_norm: bool = False,
):
"""Initialize HiFiGANPeriodDiscriminator module.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
period (int): Period.
kernel_sizes (list): Kernel sizes of initial conv layers and the final conv
layer.
channels (int): Number of initial channels.
downsample_scales (List[int]): List of downsampling scales.
max_downsample_channels (int): Number of maximum downsampling channels.
bias (bool): Whether to add bias parameter in convolution layers.
nonlinear_activation (str): Activation function module name.
nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation
function.
use_weight_norm (bool): Whether to use weight norm.
If set to true, it will be applied to all of the conv layers.
use_spectral_norm (bool): Whether to use spectral norm.
If set to true, it will be applied to all of the conv layers.
"""
super().__init__()
assert len(kernel_sizes) == 2
assert kernel_sizes[0] % 2 == 1, "Kernel size must be odd number."
assert kernel_sizes[1] % 2 == 1, "Kernel size must be odd number."
self.period = period
self.convs = torch.nn.ModuleList()
in_chs = in_channels
out_chs = channels
for downsample_scale in downsample_scales:
self.convs += [
torch.nn.Sequential(
torch.nn.Conv2d(
in_chs,
out_chs,
(kernel_sizes[0], 1),
(downsample_scale, 1),
padding=((kernel_sizes[0] - 1) // 2, 0),
),
getattr(torch.nn, nonlinear_activation)(
**nonlinear_activation_params
),
)
]
in_chs = out_chs
# NOTE(kan-bayashi): Use downsample_scale + 1?
out_chs = min(out_chs * 4, max_downsample_channels)
self.output_conv = torch.nn.Conv2d(
out_chs,
out_channels,
(kernel_sizes[1] - 1, 1),
1,
padding=((kernel_sizes[1] - 1) // 2, 0),
)
if use_weight_norm and use_spectral_norm:
raise ValueError("Either use use_weight_norm or use_spectral_norm.")
# apply weight norm
if use_weight_norm:
self.apply_weight_norm()
# apply spectral norm
if use_spectral_norm:
self.apply_spectral_norm()
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, in_channels, T).
Returns:
List[Tensor]: List of each layer's output tensors.
"""
# transform 1d to 2d -> (B, C, T/P, P)
b, c, t = x.shape
if t % self.period != 0:
n_pad = self.period - (t % self.period)
x = F.pad(x, (0, n_pad), "reflect")
t += n_pad
x = x.view(b, c, t // self.period, self.period)
# forward conv
outs = []
for layer in self.convs:
x = layer(x)
outs += [x]
x = self.output_conv(x)
x = torch.flatten(x, 1, -1)
outs += [x]
return outs
def apply_weight_norm(self):
"""Apply weight normalization module from all of the layers."""
def _apply_weight_norm(m: torch.nn.Module):
if isinstance(m, torch.nn.Conv2d):
torch.nn.utils.weight_norm(m)
logging.debug(f"Weight norm is applied to {m}.")
self.apply(_apply_weight_norm)
def apply_spectral_norm(self):
"""Apply spectral normalization module from all of the layers."""
def _apply_spectral_norm(m: torch.nn.Module):
if isinstance(m, torch.nn.Conv2d):
torch.nn.utils.spectral_norm(m)
logging.debug(f"Spectral norm is applied to {m}.")
self.apply(_apply_spectral_norm)
class HiFiGANMultiPeriodDiscriminator(torch.nn.Module):
"""HiFiGAN multi-period discriminator module."""
def __init__(
self,
periods: List[int] = [2, 3, 5, 7, 11],
discriminator_params: Dict[str, Any] = {
"in_channels": 1,
"out_channels": 1,
"kernel_sizes": [5, 3],
"channels": 32,
"downsample_scales": [3, 3, 3, 3, 1],
"max_downsample_channels": 1024,
"bias": True,
"nonlinear_activation": "LeakyReLU",
"nonlinear_activation_params": {"negative_slope": 0.1},
"use_weight_norm": True,
"use_spectral_norm": False,
},
):
"""Initialize HiFiGANMultiPeriodDiscriminator module.
Args:
periods (List[int]): List of periods.
discriminator_params (Dict[str, Any]): Parameters for hifi-gan period
discriminator module. The period parameter will be overwritten.
"""
super().__init__()
self.discriminators = torch.nn.ModuleList()
for period in periods:
params = copy.deepcopy(discriminator_params)
params["period"] = period
self.discriminators += [HiFiGANPeriodDiscriminator(**params)]
def forward(self, x: torch.Tensor) -> List[List[torch.Tensor]]:
"""Calculate forward propagation.
Args:
x (Tensor): Input noise signal (B, 1, T).
Returns:
List: List of list of each discriminator outputs, which consists of each
layer output tensors.
"""
outs = []
for f in self.discriminators:
outs += [f(x)]
return outs
class HiFiGANScaleDiscriminator(torch.nn.Module):
"""HiFi-GAN scale discriminator module."""
def __init__(
self,
in_channels: int = 1,
out_channels: int = 1,
kernel_sizes: List[int] = [15, 41, 5, 3],
channels: int = 128,
max_downsample_channels: int = 1024,
max_groups: int = 16,
bias: int = True,
downsample_scales: List[int] = [2, 2, 4, 4, 1],
nonlinear_activation: str = "LeakyReLU",
nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1},
use_weight_norm: bool = True,
use_spectral_norm: bool = False,
):
"""Initilize HiFiGAN scale discriminator module.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
kernel_sizes (List[int]): List of four kernel sizes. The first will be used
for the first conv layer, and the second is for downsampling part, and
the remaining two are for the last two output layers.
channels (int): Initial number of channels for conv layer.
max_downsample_channels (int): Maximum number of channels for downsampling
layers.
bias (bool): Whether to add bias parameter in convolution layers.
downsample_scales (List[int]): List of downsampling scales.
nonlinear_activation (str): Activation function module name.
nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation
function.
use_weight_norm (bool): Whether to use weight norm. If set to true, it will
be applied to all of the conv layers.
use_spectral_norm (bool): Whether to use spectral norm. If set to true, it
will be applied to all of the conv layers.
"""
super().__init__()
self.layers = torch.nn.ModuleList()
# check kernel size is valid
assert len(kernel_sizes) == 4
for ks in kernel_sizes:
assert ks % 2 == 1
# add first layer
self.layers += [
torch.nn.Sequential(
torch.nn.Conv1d(
in_channels,
channels,
# NOTE(kan-bayashi): Use always the same kernel size
kernel_sizes[0],
bias=bias,
padding=(kernel_sizes[0] - 1) // 2,
),
getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
)
]
# add downsample layers
in_chs = channels
out_chs = channels
# NOTE(kan-bayashi): Remove hard coding?
groups = 4
for downsample_scale in downsample_scales:
self.layers += [
torch.nn.Sequential(
torch.nn.Conv1d(
in_chs,
out_chs,
kernel_size=kernel_sizes[1],
stride=downsample_scale,
padding=(kernel_sizes[1] - 1) // 2,
groups=groups,
bias=bias,
),
getattr(torch.nn, nonlinear_activation)(
**nonlinear_activation_params
),
)
]
in_chs = out_chs
# NOTE(kan-bayashi): Remove hard coding?
out_chs = min(in_chs * 2, max_downsample_channels)
# NOTE(kan-bayashi): Remove hard coding?
groups = min(groups * 4, max_groups)
# add final layers
out_chs = min(in_chs * 2, max_downsample_channels)
self.layers += [
torch.nn.Sequential(
torch.nn.Conv1d(
in_chs,
out_chs,
kernel_size=kernel_sizes[2],
stride=1,
padding=(kernel_sizes[2] - 1) // 2,
bias=bias,
),
getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
)
]
self.layers += [
torch.nn.Conv1d(
out_chs,
out_channels,
kernel_size=kernel_sizes[3],
stride=1,
padding=(kernel_sizes[3] - 1) // 2,
bias=bias,
),
]
if use_weight_norm and use_spectral_norm:
raise ValueError("Either use use_weight_norm or use_spectral_norm.")
# apply weight norm
self.use_weight_norm = use_weight_norm
if use_weight_norm:
self.apply_weight_norm()
# apply spectral norm
self.use_spectral_norm = use_spectral_norm
if use_spectral_norm:
self.apply_spectral_norm()
# backward compatibility
self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook)
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
"""Calculate forward propagation.
Args:
x (Tensor): Input noise signal (B, 1, T).
Returns:
List[Tensor]: List of output tensors of each layer.
"""
outs = []
for f in self.layers:
x = f(x)
outs += [x]
return outs
def apply_weight_norm(self):
"""Apply weight normalization module from all of the layers."""
def _apply_weight_norm(m: torch.nn.Module):
if isinstance(m, torch.nn.Conv1d):
torch.nn.utils.weight_norm(m)
logging.debug(f"Weight norm is applied to {m}.")
self.apply(_apply_weight_norm)
def apply_spectral_norm(self):
"""Apply spectral normalization module from all of the layers."""
def _apply_spectral_norm(m: torch.nn.Module):
if isinstance(m, torch.nn.Conv1d):
torch.nn.utils.spectral_norm(m)
logging.debug(f"Spectral norm is applied to {m}.")
self.apply(_apply_spectral_norm)
def remove_weight_norm(self):
"""Remove weight normalization module from all of the layers."""
def _remove_weight_norm(m):
try:
logging.debug(f"Weight norm is removed from {m}.")
torch.nn.utils.remove_weight_norm(m)
except ValueError: # this module didn't have weight norm
return
self.apply(_remove_weight_norm)
def remove_spectral_norm(self):
"""Remove spectral normalization module from all of the layers."""
def _remove_spectral_norm(m):
try:
logging.debug(f"Spectral norm is removed from {m}.")
torch.nn.utils.remove_spectral_norm(m)
except ValueError: # this module didn't have spectral norm
return
self.apply(_remove_spectral_norm)
def _load_state_dict_pre_hook(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
"""Fix the compatibility of weight / spectral normalization issue.
Some pretrained models are trained with configs that use weight / spectral
normalization, but actually, the norm is not applied. This causes the mismatch
of the parameters with configs. To solve this issue, when parameter mismatch
happens in loading pretrained model, we remove the norm from the current model.
See also:
- https://github.com/espnet/espnet/pull/5240
- https://github.com/espnet/espnet/pull/5249
- https://github.com/kan-bayashi/ParallelWaveGAN/pull/409
"""
current_module_keys = [x for x in state_dict.keys() if x.startswith(prefix)]
if self.use_weight_norm and any(
[k.endswith("weight") for k in current_module_keys]
):
logging.warning(
"It seems weight norm is not applied in the pretrained model but the"
" current model uses it. To keep the compatibility, we remove the norm"
" from the current model. This may cause unexpected behavior due to the"
" parameter mismatch in finetuning. To avoid this issue, please change"
" the following parameters in config to false:\n"
" - discriminator_params.follow_official_norm\n"
" - discriminator_params.scale_discriminator_params.use_weight_norm\n"
" - discriminator_params.scale_discriminator_params.use_spectral_norm\n"
"\n"
"See also:\n"
" - https://github.com/espnet/espnet/pull/5240\n"
" - https://github.com/espnet/espnet/pull/5249"
)
self.remove_weight_norm()
self.use_weight_norm = False
for k in current_module_keys:
if k.endswith("weight_g") or k.endswith("weight_v"):
del state_dict[k]
if self.use_spectral_norm and any(
[k.endswith("weight") for k in current_module_keys]
):
logging.warning(
"It seems spectral norm is not applied in the pretrained model but the"
" current model uses it. To keep the compatibility, we remove the norm"
" from the current model. This may cause unexpected behavior due to the"
" parameter mismatch in finetuning. To avoid this issue, please change"
" the following parameters in config to false:\n"
" - discriminator_params.follow_official_norm\n"
" - discriminator_params.scale_discriminator_params.use_weight_norm\n"
" - discriminator_params.scale_discriminator_params.use_spectral_norm\n"
"\n"
"See also:\n"
" - https://github.com/espnet/espnet/pull/5240\n"
" - https://github.com/espnet/espnet/pull/5249"
)
self.remove_spectral_norm()
self.use_spectral_norm = False
for k in current_module_keys:
if (
k.endswith("weight_u")
or k.endswith("weight_v")
or k.endswith("weight_orig")
):
del state_dict[k]
class HiFiGANMultiScaleDiscriminator(torch.nn.Module):
"""HiFi-GAN multi-scale discriminator module."""
def __init__(
self,
scales: int = 3,
downsample_pooling: str = "AvgPool1d",
# follow the official implementation setting
downsample_pooling_params: Dict[str, Any] = {
"kernel_size": 4,
"stride": 2,
"padding": 2,
},
discriminator_params: Dict[str, Any] = {
"in_channels": 1,
"out_channels": 1,
"kernel_sizes": [15, 41, 5, 3],
"channels": 128,
"max_downsample_channels": 1024,
"max_groups": 16,
"bias": True,
"downsample_scales": [2, 2, 4, 4, 1],
"nonlinear_activation": "LeakyReLU",
"nonlinear_activation_params": {"negative_slope": 0.1},
},
follow_official_norm: bool = False,
):
"""Initilize HiFiGAN multi-scale discriminator module.
Args:
scales (int): Number of multi-scales.
downsample_pooling (str): Pooling module name for downsampling of the
inputs.
downsample_pooling_params (Dict[str, Any]): Parameters for the above pooling
module.
discriminator_params (Dict[str, Any]): Parameters for hifi-gan scale
discriminator module.
follow_official_norm (bool): Whether to follow the norm setting of the
official implementation. The first discriminator uses spectral norm
and the other discriminators use weight norm.
"""
super().__init__()
self.discriminators = torch.nn.ModuleList()
# add discriminators
for i in range(scales):
params = copy.deepcopy(discriminator_params)
if follow_official_norm:
if i == 0:
params["use_weight_norm"] = False
params["use_spectral_norm"] = True
else:
params["use_weight_norm"] = True
params["use_spectral_norm"] = False
self.discriminators += [HiFiGANScaleDiscriminator(**params)]
self.pooling = None
if scales > 1:
self.pooling = getattr(torch.nn, downsample_pooling)(
**downsample_pooling_params
)
def forward(self, x: torch.Tensor) -> List[List[torch.Tensor]]:
"""Calculate forward propagation.
Args:
x (Tensor): Input noise signal (B, 1, T).
Returns:
List[List[torch.Tensor]]: List of list of each discriminator outputs,
which consists of each layer's output tensors.
"""
outs = []
for f in self.discriminators:
outs += [f(x)]
if self.pooling is not None:
x = self.pooling(x)
return outs
class HiFiGANMultiScaleMultiPeriodDiscriminator(torch.nn.Module):
"""HiFi-GAN multi-scale + multi-period discriminator module."""
def __init__(
self,
# Multi-scale discriminator related
scales: int = 3,
scale_downsample_pooling: str = "AvgPool1d",
scale_downsample_pooling_params: Dict[str, Any] = {
"kernel_size": 4,
"stride": 2,
"padding": 2,
},
scale_discriminator_params: Dict[str, Any] = {
"in_channels": 1,
"out_channels": 1,
"kernel_sizes": [15, 41, 5, 3],
"channels": 128,
"max_downsample_channels": 1024,
"max_groups": 16,
"bias": True,
"downsample_scales": [2, 2, 4, 4, 1],
"nonlinear_activation": "LeakyReLU",
"nonlinear_activation_params": {"negative_slope": 0.1},
},
follow_official_norm: bool = True,
# Multi-period discriminator related
periods: List[int] = [2, 3, 5, 7, 11],
period_discriminator_params: Dict[str, Any] = {
"in_channels": 1,
"out_channels": 1,
"kernel_sizes": [5, 3],
"channels": 32,
"downsample_scales": [3, 3, 3, 3, 1],
"max_downsample_channels": 1024,
"bias": True,
"nonlinear_activation": "LeakyReLU",
"nonlinear_activation_params": {"negative_slope": 0.1},
"use_weight_norm": True,
"use_spectral_norm": False,
},
):
"""Initilize HiFiGAN multi-scale + multi-period discriminator module.
Args:
scales (int): Number of multi-scales.
scale_downsample_pooling (str): Pooling module name for downsampling of the
inputs.
scale_downsample_pooling_params (dict): Parameters for the above pooling
module.
scale_discriminator_params (dict): Parameters for hifi-gan scale
discriminator module.
follow_official_norm (bool): Whether to follow the norm setting of the
official implementation. The first discriminator uses spectral norm and
the other discriminators use weight norm.
periods (list): List of periods.
period_discriminator_params (dict): Parameters for hifi-gan period
discriminator module. The period parameter will be overwritten.
"""
super().__init__()
self.msd = HiFiGANMultiScaleDiscriminator(
scales=scales,
downsample_pooling=scale_downsample_pooling,
downsample_pooling_params=scale_downsample_pooling_params,
discriminator_params=scale_discriminator_params,
follow_official_norm=follow_official_norm,
)
self.mpd = HiFiGANMultiPeriodDiscriminator(
periods=periods,
discriminator_params=period_discriminator_params,
)
def forward(self, x: torch.Tensor) -> List[List[torch.Tensor]]:
"""Calculate forward propagation.
Args:
x (Tensor): Input noise signal (B, 1, T).
Returns:
List[List[Tensor]]: List of list of each discriminator outputs,
which consists of each layer output tensors. Multi scale and
multi period ones are concatenated.
"""
msd_outs = self.msd(x)
mpd_outs = self.mpd(x)
return msd_outs + mpd_outs
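Similarly, a small sketch of the removed HiFi-GAN pieces in the GAN loop (toy shapes; in_channels=192 mirrors the VITS latent dimension above):
import torch
from hifigan import HiFiGANGenerator, HiFiGANMultiScaleMultiPeriodDiscriminator

generator = HiFiGANGenerator(in_channels=192)
discriminator = HiFiGANMultiScaleMultiPeriodDiscriminator()

z = torch.randn(2, 192, 32)     # latent segment (B, in_channels, T)
wav = generator(z)              # (B, 1, 32 * 256) with the default upsample scales [8, 8, 2, 2]
disc_outs = discriminator(wav)  # per-discriminator lists of per-layer feature maps; run again on real audio for the adversarial losses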


@@ -1,336 +0,0 @@
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/loss.py
# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""HiFiGAN-related loss modules.
This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN.
"""
from typing import List, Tuple, Union
import torch
import torch.distributions as D
import torch.nn.functional as F
from lhotse.features.kaldi import Wav2LogFilterBank
class GeneratorAdversarialLoss(torch.nn.Module):
"""Generator adversarial loss module."""
def __init__(
self,
average_by_discriminators: bool = True,
loss_type: str = "mse",
):
"""Initialize GeneratorAversarialLoss module.
Args:
average_by_discriminators (bool): Whether to average the loss by
the number of discriminators.
loss_type (str): Loss type, "mse" or "hinge".
"""
super().__init__()
self.average_by_discriminators = average_by_discriminators
assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
if loss_type == "mse":
self.criterion = self._mse_loss
else:
self.criterion = self._hinge_loss
def forward(
self,
outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
) -> torch.Tensor:
"""Calcualate generator adversarial loss.
Args:
outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
outputs, list of discriminator outputs, or list of list of discriminator
outputs.
Returns:
Tensor: Generator adversarial loss value.
"""
if isinstance(outputs, (tuple, list)):
adv_loss = 0.0
for i, outputs_ in enumerate(outputs):
if isinstance(outputs_, (tuple, list)):
# NOTE(kan-bayashi): case including feature maps
outputs_ = outputs_[-1]
adv_loss += self.criterion(outputs_)
if self.average_by_discriminators:
adv_loss /= i + 1
else:
adv_loss = self.criterion(outputs)
return adv_loss
def _mse_loss(self, x):
return F.mse_loss(x, x.new_ones(x.size()))
def _hinge_loss(self, x):
return -x.mean()
class DiscriminatorAdversarialLoss(torch.nn.Module):
"""Discriminator adversarial loss module."""
def __init__(
self,
average_by_discriminators: bool = True,
loss_type: str = "mse",
):
"""Initialize DiscriminatorAversarialLoss module.
Args:
average_by_discriminators (bool): Whether to average the loss by
the number of discriminators.
loss_type (str): Loss type, "mse" or "hinge".
"""
super().__init__()
self.average_by_discriminators = average_by_discriminators
assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
if loss_type == "mse":
self.fake_criterion = self._mse_fake_loss
self.real_criterion = self._mse_real_loss
else:
self.fake_criterion = self._hinge_fake_loss
self.real_criterion = self._hinge_real_loss
def forward(
self,
outputs_hat: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Calcualate discriminator adversarial loss.
Args:
outputs_hat (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
outputs, list of discriminator outputs, or list of list of discriminator
outputs calculated from generator.
outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
outputs, list of discriminator outputs, or list of list of discriminator
outputs calculated from groundtruth.
Returns:
Tensor: Discriminator real loss value.
Tensor: Discriminator fake loss value.
"""
if isinstance(outputs, (tuple, list)):
real_loss = 0.0
fake_loss = 0.0
for i, (outputs_hat_, outputs_) in enumerate(zip(outputs_hat, outputs)):
if isinstance(outputs_hat_, (tuple, list)):
# NOTE(kan-bayashi): case including feature maps
outputs_hat_ = outputs_hat_[-1]
outputs_ = outputs_[-1]
real_loss += self.real_criterion(outputs_)
fake_loss += self.fake_criterion(outputs_hat_)
if self.average_by_discriminators:
fake_loss /= i + 1
real_loss /= i + 1
else:
real_loss = self.real_criterion(outputs)
fake_loss = self.fake_criterion(outputs_hat)
return real_loss, fake_loss
def _mse_real_loss(self, x: torch.Tensor) -> torch.Tensor:
return F.mse_loss(x, x.new_ones(x.size()))
def _mse_fake_loss(self, x: torch.Tensor) -> torch.Tensor:
return F.mse_loss(x, x.new_zeros(x.size()))
def _hinge_real_loss(self, x: torch.Tensor) -> torch.Tensor:
return -torch.mean(torch.min(x - 1, x.new_zeros(x.size())))
def _hinge_fake_loss(self, x: torch.Tensor) -> torch.Tensor:
return -torch.mean(torch.min(-x - 1, x.new_zeros(x.size())))
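# --- minimal usage sketch (not part of the original file) --------------------
# `outputs_hat` would come from the discriminator run on generated audio and
# `outputs` from the discriminator run on real audio; shapes are arbitrary.
def _test_discriminator_adversarial_loss():
    criterion = DiscriminatorAdversarialLoss(loss_type="hinge")
    outputs_hat = [torch.randn(2, 1, 32), torch.randn(2, 1, 16)]
    outputs = [torch.randn(2, 1, 32), torch.randn(2, 1, 16)]
    real_loss, fake_loss = criterion(outputs_hat, outputs)
    print(real_loss.item(), fake_loss.item())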
class FeatureMatchLoss(torch.nn.Module):
"""Feature matching loss module."""
def __init__(
self,
average_by_layers: bool = True,
average_by_discriminators: bool = True,
include_final_outputs: bool = False,
):
"""Initialize FeatureMatchLoss module.
Args:
average_by_layers (bool): Whether to average the loss by the number
of layers.
average_by_discriminators (bool): Whether to average the loss by
the number of discriminators.
include_final_outputs (bool): Whether to include the final output of
each discriminator for loss calculation.
"""
super().__init__()
self.average_by_layers = average_by_layers
self.average_by_discriminators = average_by_discriminators
self.include_final_outputs = include_final_outputs
def forward(
self,
feats_hat: Union[List[List[torch.Tensor]], List[torch.Tensor]],
feats: Union[List[List[torch.Tensor]], List[torch.Tensor]],
) -> torch.Tensor:
"""Calculate feature matching loss.
Args:
            feats_hat (Union[List[List[Tensor]], List[Tensor]]): List of list of
                discriminator outputs or list of discriminator outputs calculated
                from generator's outputs.
            feats (Union[List[List[Tensor]], List[Tensor]]): List of list of
                discriminator outputs or list of discriminator outputs calculated
                from groundtruth.
Returns:
Tensor: Feature matching loss value.
"""
feat_match_loss = 0.0
for i, (feats_hat_, feats_) in enumerate(zip(feats_hat, feats)):
feat_match_loss_ = 0.0
if not self.include_final_outputs:
feats_hat_ = feats_hat_[:-1]
feats_ = feats_[:-1]
for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)):
feat_match_loss_ += F.l1_loss(feat_hat_, feat_.detach())
if self.average_by_layers:
feat_match_loss_ /= j + 1
feat_match_loss += feat_match_loss_
if self.average_by_discriminators:
feat_match_loss /= i + 1
return feat_match_loss
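# --- minimal usage sketch (not part of the original file) --------------------
# Feature maps from generated and groundtruth audio must have matching
# structure; with the default include_final_outputs=False, the last entry of
# each inner list (the final score map) is excluded from the L1 matching.
def _test_feature_match_loss():
    criterion = FeatureMatchLoss()
    feats_hat = [[torch.randn(2, 8, 32), torch.randn(2, 8, 16), torch.randn(2, 1, 16)]]
    feats = [[torch.randn(2, 8, 32), torch.randn(2, 8, 16), torch.randn(2, 1, 16)]]
    loss = criterion(feats_hat, feats)
    assert loss.dim() == 0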
class MelSpectrogramLoss(torch.nn.Module):
"""Mel-spectrogram loss."""
def __init__(
self,
sampling_rate: int = 22050,
frame_length: int = 1024, # in samples
frame_shift: int = 256, # in samples
n_mels: int = 80,
use_fft_mag: bool = True,
):
super().__init__()
self.wav_to_mel = Wav2LogFilterBank(
sampling_rate=sampling_rate,
            frame_length=frame_length / sampling_rate,  # in seconds
            frame_shift=frame_shift / sampling_rate,  # in seconds
use_fft_mag=use_fft_mag,
num_filters=n_mels,
)
def forward(
self,
y_hat: torch.Tensor,
y: torch.Tensor,
return_mel: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]:
"""Calculate Mel-spectrogram loss.
Args:
y_hat (Tensor): Generated waveform tensor (B, 1, T).
y (Tensor): Groundtruth waveform tensor (B, 1, T).
            return_mel (bool): Whether to also return the computed mel-spectrograms
                of the generated and groundtruth waveforms.
Returns:
Tensor: Mel-spectrogram loss value.
"""
mel_hat = self.wav_to_mel(y_hat.squeeze(1))
mel = self.wav_to_mel(y.squeeze(1))
mel_loss = F.l1_loss(mel_hat, mel)
if return_mel:
return mel_loss, (mel_hat, mel)
return mel_loss
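# --- minimal usage sketch (not part of the original file) --------------------
# One second of audio at 22.05 kHz; the loss is the L1 distance between the
# log-mel filterbanks of the generated and groundtruth waveforms.
def _test_mel_spectrogram_loss():
    criterion = MelSpectrogramLoss(sampling_rate=22050)
    y_hat = torch.randn(2, 1, 22050)
    y = torch.randn(2, 1, 22050)
    mel_loss, (mel_hat, mel) = criterion(y_hat, y, return_mel=True)
    print(mel_loss.item(), mel_hat.shape, mel.shape)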
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/loss.py
"""VITS-related loss modules.
This code is based on https://github.com/jaywalnut310/vits.
"""
class KLDivergenceLoss(torch.nn.Module):
"""KL divergence loss."""
def forward(
self,
z_p: torch.Tensor,
logs_q: torch.Tensor,
m_p: torch.Tensor,
logs_p: torch.Tensor,
z_mask: torch.Tensor,
) -> torch.Tensor:
"""Calculate KL divergence loss.
Args:
z_p (Tensor): Flow hidden representation (B, H, T_feats).
logs_q (Tensor): Posterior encoder projected scale (B, H, T_feats).
m_p (Tensor): Expanded text encoder projected mean (B, H, T_feats).
logs_p (Tensor): Expanded text encoder projected scale (B, H, T_feats).
z_mask (Tensor): Mask tensor (B, 1, T_feats).
Returns:
Tensor: KL divergence loss.
"""
z_p = z_p.float()
logs_q = logs_q.float()
m_p = m_p.float()
logs_p = logs_p.float()
z_mask = z_mask.float()
kl = logs_p - logs_q - 0.5
kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
kl = torch.sum(kl * z_mask)
loss = kl / torch.sum(z_mask)
return loss
class KLDivergenceLossWithoutFlow(torch.nn.Module):
"""KL divergence loss without flow."""
def forward(
self,
m_q: torch.Tensor,
logs_q: torch.Tensor,
m_p: torch.Tensor,
logs_p: torch.Tensor,
) -> torch.Tensor:
"""Calculate KL divergence loss without flow.
Args:
m_q (Tensor): Posterior encoder projected mean (B, H, T_feats).
logs_q (Tensor): Posterior encoder projected scale (B, H, T_feats).
m_p (Tensor): Expanded text encoder projected mean (B, H, T_feats).
logs_p (Tensor): Expanded text encoder projected scale (B, H, T_feats).
"""
posterior_norm = D.Normal(m_q, torch.exp(logs_q))
prior_norm = D.Normal(m_p, torch.exp(logs_p))
loss = D.kl_divergence(posterior_norm, prior_norm).mean()
return loss
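# --- minimal numerical sketch (not part of the original file) ----------------
# When the posterior equals the prior, the closed-form KL term is exactly zero
# and the masked Monte-Carlo estimate stays close to zero; shapes follow the
# (B, H, T_feats) convention used above.
def _test_kl_divergence_losses():
    B, H, T = 2, 192, 17
    m = torch.randn(B, H, T)
    logs = torch.randn(B, H, T) * 0.1
    z_mask = torch.ones(B, 1, T)
    # sample z_p from the prior so the stochastic estimate is small
    z_p = m + torch.randn_like(m) * torch.exp(logs)
    kl_flow = KLDivergenceLoss()(z_p, logs, m, logs, z_mask)
    kl_no_flow = KLDivergenceLossWithoutFlow()(m, logs, m, logs)
    print(kl_flow.item(), kl_no_flow.item())  # kl_no_flow should be ~0.0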

View File

@ -1,117 +0,0 @@
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/posterior_encoder.py
# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Posterior encoder module in VITS.
This code is based on https://github.com/jaywalnut310/vits.
"""
from typing import Optional, Tuple
import torch
from icefall.utils import make_pad_mask
from wavenet import WaveNet, Conv1d
class PosteriorEncoder(torch.nn.Module):
"""Posterior encoder module in VITS.
This is a module of posterior encoder described in `Conditional Variational
Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`_.
.. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
Text-to-Speech`: https://arxiv.org/abs/2006.04558
"""
def __init__(
self,
in_channels: int = 513,
out_channels: int = 192,
hidden_channels: int = 192,
kernel_size: int = 5,
layers: int = 16,
stacks: int = 1,
base_dilation: int = 1,
global_channels: int = -1,
dropout_rate: float = 0.0,
bias: bool = True,
use_weight_norm: bool = True,
):
"""Initilialize PosteriorEncoder module.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
hidden_channels (int): Number of hidden channels.
kernel_size (int): Kernel size in WaveNet.
layers (int): Number of layers of WaveNet.
stacks (int): Number of repeat stacking of WaveNet.
base_dilation (int): Base dilation factor.
global_channels (int): Number of global conditioning channels.
dropout_rate (float): Dropout rate.
bias (bool): Whether to use bias parameters in conv.
use_weight_norm (bool): Whether to apply weight norm.
"""
super().__init__()
# define modules
self.input_conv = Conv1d(in_channels, hidden_channels, 1)
self.encoder = WaveNet(
in_channels=-1,
out_channels=-1,
kernel_size=kernel_size,
layers=layers,
stacks=stacks,
base_dilation=base_dilation,
residual_channels=hidden_channels,
aux_channels=-1,
gate_channels=hidden_channels * 2,
skip_channels=hidden_channels,
global_channels=global_channels,
dropout_rate=dropout_rate,
bias=bias,
use_weight_norm=use_weight_norm,
use_first_conv=False,
use_last_conv=False,
scale_residual=False,
scale_skip_connect=True,
)
self.proj = Conv1d(hidden_channels, out_channels * 2, 1)
def forward(
self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, in_channels, T_feats).
x_lengths (Tensor): Length tensor (B,).
g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
Returns:
Tensor: Encoded hidden representation tensor (B, out_channels, T_feats).
Tensor: Projected mean tensor (B, out_channels, T_feats).
Tensor: Projected scale tensor (B, out_channels, T_feats).
Tensor: Mask tensor for input tensor (B, 1, T_feats).
"""
x_mask = (
(~make_pad_mask(x_lengths))
.unsqueeze(1)
.to(
dtype=x.dtype,
device=x.device,
)
)
x = self.input_conv(x) * x_mask
x = self.encoder(x, x_mask, g=g)
stats = self.proj(x) * x_mask
m, logs = stats.split(stats.size(1) // 2, dim=1)
z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
return z, m, logs, x_mask
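# --- minimal usage sketch (not part of the original file) --------------------
# Feature dimension 513 matches the default in_channels (e.g. a linear
# spectrogram with n_fft=1024); the lengths below are arbitrary.
def _test_posterior_encoder():
    encoder = PosteriorEncoder(in_channels=513, out_channels=192)
    x = torch.randn(2, 513, 50)
    x_lengths = torch.tensor([50, 42])
    z, m, logs, x_mask = encoder(x, x_lengths)
    print(z.shape, m.shape, logs.shape, x_mask.shape)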

View File

@ -1,229 +0,0 @@
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/residual_coupling.py
# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Residual affine coupling modules in VITS.
This code is based on https://github.com/jaywalnut310/vits.
"""
from typing import Optional, Tuple, Union
import torch
from flow import FlipFlow
from wavenet import WaveNet
class ResidualAffineCouplingBlock(torch.nn.Module):
"""Residual affine coupling block module.
    This is a module of residual affine coupling block, which is used as "Flow" in
`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
Text-to-Speech`_.
.. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
Text-to-Speech`: https://arxiv.org/abs/2006.04558
"""
def __init__(
self,
in_channels: int = 192,
hidden_channels: int = 192,
flows: int = 4,
kernel_size: int = 5,
base_dilation: int = 1,
layers: int = 4,
global_channels: int = -1,
dropout_rate: float = 0.0,
use_weight_norm: bool = True,
bias: bool = True,
use_only_mean: bool = True,
):
"""Initilize ResidualAffineCouplingBlock module.
Args:
in_channels (int): Number of input channels.
hidden_channels (int): Number of hidden channels.
flows (int): Number of flows.
kernel_size (int): Kernel size for WaveNet.
base_dilation (int): Base dilation factor for WaveNet.
layers (int): Number of layers of WaveNet.
global_channels (int): Number of global channels.
dropout_rate (float): Dropout rate.
use_weight_norm (bool): Whether to use weight normalization in WaveNet.
            bias (bool): Whether to use bias parameters in WaveNet.
use_only_mean (bool): Whether to estimate only mean.
"""
super().__init__()
self.flows = torch.nn.ModuleList()
for i in range(flows):
self.flows += [
ResidualAffineCouplingLayer(
in_channels=in_channels,
hidden_channels=hidden_channels,
kernel_size=kernel_size,
base_dilation=base_dilation,
layers=layers,
stacks=1,
global_channels=global_channels,
dropout_rate=dropout_rate,
use_weight_norm=use_weight_norm,
bias=bias,
use_only_mean=use_only_mean,
)
]
self.flows += [FlipFlow()]
def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
g: Optional[torch.Tensor] = None,
inverse: bool = False,
) -> torch.Tensor:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, in_channels, T).
            x_mask (Tensor): Mask tensor (B, 1, T).
g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
inverse (bool): Whether to inverse the flow.
Returns:
Tensor: Output tensor (B, in_channels, T).
"""
if not inverse:
for flow in self.flows:
x, _ = flow(x, x_mask, g=g, inverse=inverse)
else:
for flow in reversed(self.flows):
x = flow(x, x_mask, g=g, inverse=inverse)
return x
class ResidualAffineCouplingLayer(torch.nn.Module):
"""Residual affine coupling layer."""
def __init__(
self,
in_channels: int = 192,
hidden_channels: int = 192,
kernel_size: int = 5,
base_dilation: int = 1,
layers: int = 5,
stacks: int = 1,
global_channels: int = -1,
dropout_rate: float = 0.0,
use_weight_norm: bool = True,
bias: bool = True,
use_only_mean: bool = True,
):
"""Initialzie ResidualAffineCouplingLayer module.
Args:
in_channels (int): Number of input channels.
hidden_channels (int): Number of hidden channels.
kernel_size (int): Kernel size for WaveNet.
base_dilation (int): Base dilation factor for WaveNet.
layers (int): Number of layers of WaveNet.
stacks (int): Number of stacks of WaveNet.
global_channels (int): Number of global channels.
dropout_rate (float): Dropout rate.
use_weight_norm (bool): Whether to use weight normalization in WaveNet.
            bias (bool): Whether to use bias parameters in WaveNet.
use_only_mean (bool): Whether to estimate only mean.
"""
assert in_channels % 2 == 0, "in_channels should be divisible by 2"
super().__init__()
self.half_channels = in_channels // 2
self.use_only_mean = use_only_mean
# define modules
self.input_conv = torch.nn.Conv1d(
self.half_channels,
hidden_channels,
1,
)
self.encoder = WaveNet(
in_channels=-1,
out_channels=-1,
kernel_size=kernel_size,
layers=layers,
stacks=stacks,
base_dilation=base_dilation,
residual_channels=hidden_channels,
aux_channels=-1,
gate_channels=hidden_channels * 2,
skip_channels=hidden_channels,
global_channels=global_channels,
dropout_rate=dropout_rate,
bias=bias,
use_weight_norm=use_weight_norm,
use_first_conv=False,
use_last_conv=False,
scale_residual=False,
scale_skip_connect=True,
)
if use_only_mean:
self.proj = torch.nn.Conv1d(
hidden_channels,
self.half_channels,
1,
)
else:
self.proj = torch.nn.Conv1d(
hidden_channels,
self.half_channels * 2,
1,
)
self.proj.weight.data.zero_()
self.proj.bias.data.zero_()
def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
g: Optional[torch.Tensor] = None,
inverse: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, in_channels, T).
            x_mask (Tensor): Mask tensor (B, 1, T).
g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
inverse (bool): Whether to inverse the flow.
Returns:
Tensor: Output tensor (B, in_channels, T).
Tensor: Log-determinant tensor for NLL (B,) if not inverse.
"""
xa, xb = x.split(x.size(1) // 2, dim=1)
h = self.input_conv(xa) * x_mask
h = self.encoder(h, x_mask, g=g)
stats = self.proj(h) * x_mask
if not self.use_only_mean:
m, logs = stats.split(stats.size(1) // 2, dim=1)
else:
m = stats
logs = torch.zeros_like(m)
if not inverse:
xb = m + xb * torch.exp(logs) * x_mask
x = torch.cat([xa, xb], 1)
logdet = torch.sum(logs, [1, 2])
return x, logdet
else:
xb = (xb - m) * torch.exp(-logs) * x_mask
x = torch.cat([xa, xb], 1)
return x
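# --- minimal sanity-check sketch (not part of the original file) -------------
# Affine coupling flows are exactly invertible, so running the block forward
# and then inverse should recover the input up to floating-point error; the
# smaller hidden size and flow count below are only to keep the check fast.
def _test_residual_affine_coupling_block():
    block = ResidualAffineCouplingBlock(
        in_channels=192, hidden_channels=64, flows=2, layers=2
    )
    x = torch.randn(2, 192, 30)
    x_mask = torch.ones(2, 1, 30)
    z = block(x, x_mask)  # forward direction
    x_hat = block(z, x_mask, inverse=True)  # inverse direction
    print(torch.allclose(x, x_hat, atol=1e-4))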

View File

@ -1,684 +0,0 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Text encoder module in VITS.
This code is based on
- https://github.com/jaywalnut310/vits
- https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/text_encoder.py
- https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/transducer_stateless/conformer.py
"""
import copy
import math
from typing import Optional, Tuple
import torch
from torch import Tensor, nn
from icefall.utils import is_jit_tracing, make_pad_mask
class TextEncoder(torch.nn.Module):
"""Text encoder module in VITS.
This is a module of text encoder described in `Conditional Variational Autoencoder
with Adversarial Learning for End-to-End Text-to-Speech`.
"""
def __init__(
self,
vocabs: int,
d_model: int = 192,
num_heads: int = 2,
dim_feedforward: int = 768,
cnn_module_kernel: int = 5,
num_layers: int = 6,
dropout: float = 0.1,
):
"""Initialize TextEncoder module.
Args:
vocabs (int): Vocabulary size.
d_model (int): attention dimension
num_heads (int): number of attention heads
            dim_feedforward (int): feedforward dimension
cnn_module_kernel (int): convolution kernel size
num_layers (int): number of encoder layers
dropout (float): dropout rate
"""
super().__init__()
self.d_model = d_model
# define modules
self.emb = torch.nn.Embedding(vocabs, d_model)
torch.nn.init.normal_(self.emb.weight, 0.0, d_model**-0.5)
# We use conformer as text encoder
self.encoder = Transformer(
d_model=d_model,
num_heads=num_heads,
dim_feedforward=dim_feedforward,
cnn_module_kernel=cnn_module_kernel,
num_layers=num_layers,
dropout=dropout,
)
self.proj = torch.nn.Conv1d(d_model, d_model * 2, 1)
def forward(
self,
x: torch.Tensor,
x_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Calculate forward propagation.
Args:
x (Tensor): Input index tensor (B, T_text).
x_lengths (Tensor): Length tensor (B,).
Returns:
Tensor: Encoded hidden representation (B, attention_dim, T_text).
Tensor: Projected mean tensor (B, attention_dim, T_text).
Tensor: Projected scale tensor (B, attention_dim, T_text).
Tensor: Mask tensor for input tensor (B, 1, T_text).
"""
# (B, T_text, embed_dim)
x = self.emb(x) * math.sqrt(self.d_model)
assert x.size(1) == x_lengths.max().item()
# (B, T_text)
pad_mask = make_pad_mask(x_lengths)
# encoder assume the channel last (B, T_text, embed_dim)
x = self.encoder(x, key_padding_mask=pad_mask)
# convert the channel first (B, embed_dim, T_text)
x = x.transpose(1, 2)
non_pad_mask = (~pad_mask).unsqueeze(1)
stats = self.proj(x) * non_pad_mask
m, logs = stats.split(stats.size(1) // 2, dim=1)
return x, m, logs, non_pad_mask
class Transformer(nn.Module):
"""
Args:
d_model (int): attention dimension
num_heads (int): number of attention heads
        dim_feedforward (int): feedforward dimension
cnn_module_kernel (int): convolution kernel size
num_layers (int): number of encoder layers
dropout (float): dropout rate
"""
def __init__(
self,
d_model: int = 192,
num_heads: int = 2,
dim_feedforward: int = 768,
cnn_module_kernel: int = 5,
num_layers: int = 6,
dropout: float = 0.1,
) -> None:
super().__init__()
self.num_layers = num_layers
self.d_model = d_model
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
encoder_layer = TransformerEncoderLayer(
d_model=d_model,
num_heads=num_heads,
dim_feedforward=dim_feedforward,
cnn_module_kernel=cnn_module_kernel,
dropout=dropout,
)
self.encoder = TransformerEncoder(encoder_layer, num_layers)
self.after_norm = nn.LayerNorm(d_model)
def forward(
self, x: Tensor, key_padding_mask: Tensor
    ) -> Tensor:
        """
        Args:
          x:
            The input tensor. Its shape is (batch_size, seq_len, feature_dim).
          key_padding_mask:
            A boolean tensor of shape (batch_size, seq_len) that is True for
            padded positions in `x`.
"""
x, pos_emb = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
x = self.encoder(x, pos_emb, key_padding_mask=key_padding_mask) # (T, N, C)
x = self.after_norm(x)
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
return x
class TransformerEncoderLayer(nn.Module):
"""
TransformerEncoderLayer is made up of self-attn and feedforward.
Args:
d_model: the number of expected features in the input.
num_heads: the number of heads in the multi-head attention models.
dim_feedforward: the dimension of the feed-forward network model.
dropout: the dropout value (default=0.1).
"""
def __init__(
self,
d_model: int,
num_heads: int,
dim_feedforward: int,
cnn_module_kernel: int,
dropout: float = 0.1,
) -> None:
super(TransformerEncoderLayer, self).__init__()
self.feed_forward_macaron = nn.Sequential(
nn.Linear(d_model, dim_feedforward),
Swish(),
nn.Dropout(dropout),
nn.Linear(dim_feedforward, d_model),
)
self.self_attn = RelPositionMultiheadAttention(
d_model, num_heads, dropout=dropout
)
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
self.feed_forward = nn.Sequential(
nn.Linear(d_model, dim_feedforward),
Swish(),
nn.Dropout(dropout),
nn.Linear(dim_feedforward, d_model),
)
self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module
self.norm_mha = nn.LayerNorm(d_model) # for the MHA module
self.norm_conv = nn.LayerNorm(d_model) # for the CNN module
self.norm_final = nn.LayerNorm(d_model) # for the final output of the block
self.norm_ff = nn.LayerNorm(d_model) # for the FNN module
self.ff_scale = 0.5
self.dropout = nn.Dropout(dropout)
def forward(
self,
src: Tensor,
pos_emb: Tensor,
key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
"""
Pass the input through the transformer encoder layer.
Args:
src: the sequence to the encoder layer, of shape (seq_len, batch_size, embed_dim).
pos_emb: Positional embedding tensor, of shape (1, seq_len*2-1, pos_dim).
key_padding_mask: the mask for the src keys per batch, of shape (batch_size, seq_len)
"""
# macaron style feed-forward module
src = src + self.ff_scale * self.dropout(
self.feed_forward_macaron(self.norm_ff_macaron(src))
)
# multi-head self-attention module
src_attn = self.self_attn(
self.norm_mha(src),
pos_emb=pos_emb,
key_padding_mask=key_padding_mask,
)
src = src + self.dropout(src_attn)
# convolution module
src = src + self.dropout(self.conv_module(self.norm_conv(src)))
# feed-forward module
src = src + self.dropout(self.feed_forward(self.norm_ff(src)))
src = self.norm_final(src)
return src
class TransformerEncoder(nn.Module):
r"""TransformerEncoder is a stack of N encoder layers
Args:
encoder_layer: an instance of the TransformerEncoderLayer class.
num_layers: the number of sub-encoder-layers in the encoder.
"""
def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None:
super().__init__()
self.layers = nn.ModuleList(
[copy.deepcopy(encoder_layer) for i in range(num_layers)]
)
self.num_layers = num_layers
def forward(
self,
src: Tensor,
pos_emb: Tensor,
key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
r"""Pass the input through the encoder layers in turn.
Args:
src: the sequence to the encoder layer, of shape (seq_len, batch_size, embed_dim).
pos_emb: Positional embedding tensor, of shape (1, seq_len*2-1, pos_dim).
key_padding_mask: the mask for the src keys per batch, of shape (batch_size, seq_len)
"""
output = src
for layer_index, mod in enumerate(self.layers):
output = mod(
output,
pos_emb,
key_padding_mask=key_padding_mask,
)
return output
class RelPositionalEncoding(torch.nn.Module):
"""Relative positional encoding module.
See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py
Args:
d_model: Embedding dimension.
dropout_rate: Dropout rate.
max_len: Maximum input length.
"""
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
"""Construct an PositionalEncoding object."""
super(RelPositionalEncoding, self).__init__()
self.d_model = d_model
self.xscale = math.sqrt(self.d_model)
self.dropout = torch.nn.Dropout(p=dropout_rate)
self.pe = None
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
def extend_pe(self, x: Tensor) -> None:
"""Reset the positional encodings."""
x_size = x.size(1)
if self.pe is not None:
# self.pe contains both positive and negative parts
# the length of self.pe is 2 * input_len - 1
if self.pe.size(1) >= x_size * 2 - 1:
# Note: TorchScript doesn't implement operator== for torch.Device
if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
        # Suppose `i` is the position of the query vector and `j` is the
        # position of the key vector. We use positive relative positions when
        # the key is to the left (i > j) and negative relative positions
        # otherwise (i < j).
pe_positive = torch.zeros(x_size, self.d_model)
pe_negative = torch.zeros(x_size, self.d_model)
position = torch.arange(0, x_size, dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.d_model)
)
pe_positive[:, 0::2] = torch.sin(position * div_term)
pe_positive[:, 1::2] = torch.cos(position * div_term)
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
        # Reverse the order of positive indices and concat both positive and
# negative indices. This is used to support the shifting trick
# as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
pe_negative = pe_negative[1:].unsqueeze(0)
pe = torch.cat([pe_positive, pe_negative], dim=1)
self.pe = pe.to(device=x.device, dtype=x.dtype)
def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
"""Add positional encoding.
Args:
x (torch.Tensor): Input tensor (batch, time, `*`).
Returns:
torch.Tensor: Encoded tensor (batch, time, `*`).
            torch.Tensor: Positional embedding tensor (batch, 2*time-1, `*`).
"""
self.extend_pe(x)
x = x * self.xscale
pos_emb = self.pe[
:,
self.pe.size(1) // 2
- x.size(1)
+ 1 : self.pe.size(1) // 2 # noqa E203
+ x.size(1),
]
return self.dropout(x), self.dropout(pos_emb)
class RelPositionMultiheadAttention(nn.Module):
r"""Multi-Head Attention layer with relative position encoding
See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
Args:
embed_dim: total dimension of the model.
num_heads: parallel attention heads.
dropout: a Dropout layer on attn_output_weights. Default: 0.0.
"""
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
) -> None:
super(RelPositionMultiheadAttention, self).__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
assert (
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
# linear transformation for positional encoding.
self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False)
# these two learnable bias are used in matrix c and matrix d
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
self._reset_parameters()
def _reset_parameters(self) -> None:
nn.init.xavier_uniform_(self.in_proj.weight)
nn.init.constant_(self.in_proj.bias, 0.0)
nn.init.constant_(self.out_proj.bias, 0.0)
nn.init.xavier_uniform_(self.pos_bias_u)
nn.init.xavier_uniform_(self.pos_bias_v)
def rel_shift(self, x: Tensor) -> Tensor:
"""Compute relative positional encoding.
Args:
x: Input tensor (batch, head, seq_len, 2*seq_len-1).
Returns:
Tensor: tensor of shape (batch, head, seq_len, seq_len)
"""
(batch_size, num_heads, seq_len, n) = x.shape
if not is_jit_tracing():
assert n == 2 * seq_len - 1, f"{n} == 2 * {seq_len} - 1"
if is_jit_tracing():
rows = torch.arange(start=seq_len - 1, end=-1, step=-1)
cols = torch.arange(seq_len)
rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
indexes = rows + cols
x = x.reshape(-1, n)
x = torch.gather(x, dim=1, index=indexes)
x = x.reshape(batch_size, num_heads, seq_len, seq_len)
return x
else:
# Note: TorchScript requires explicit arg for stride()
batch_stride = x.stride(0)
head_stride = x.stride(1)
time_stride = x.stride(2)
n_stride = x.stride(3)
return x.as_strided(
(batch_size, num_heads, seq_len, seq_len),
(batch_stride, head_stride, time_stride - n_stride, n_stride),
storage_offset=n_stride * (seq_len - 1),
)
def forward(
self,
x: Tensor,
pos_emb: Tensor,
key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
"""
Args:
x: Input tensor of shape (seq_len, batch_size, embed_dim)
pos_emb: Positional embedding tensor, (1, 2*seq_len-1, pos_dim)
key_padding_mask: if provided, specified padding elements in the key will
                be ignored by the attention. This is a binary mask. When the value is True,
the corresponding value on the attention layer will be filled with -inf.
Its shape is (batch_size, seq_len).
Outputs:
A tensor of shape (seq_len, batch_size, embed_dim).
"""
seq_len, batch_size, _ = x.shape
scaling = float(self.head_dim) ** -0.5
q, k, v = self.in_proj(x).chunk(3, dim=-1)
q = q.contiguous().view(seq_len, batch_size, self.num_heads, self.head_dim)
k = k.contiguous().view(seq_len, batch_size, self.num_heads, self.head_dim)
v = (
v.contiguous()
.view(seq_len, batch_size * self.num_heads, self.head_dim)
.transpose(0, 1)
)
q = q.transpose(0, 1) # (batch_size, seq_len, num_head, head_dim)
p = self.linear_pos(pos_emb).view(
pos_emb.size(0), -1, self.num_heads, self.head_dim
)
        # (1, 2*seq_len-1, num_head, head_dim) -> (1, num_head, head_dim, 2*seq_len-1)
p = p.permute(0, 2, 3, 1)
# (batch_size, num_head, seq_len, head_dim)
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
# compute attention score
# first compute matrix a and matrix c
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
k = k.permute(1, 2, 3, 0) # (batch_size, num_head, head_dim, seq_len)
matrix_ac = torch.matmul(
q_with_bias_u, k
) # (batch_size, num_head, seq_len, seq_len)
# compute matrix b and matrix d
matrix_bd = torch.matmul(
q_with_bias_v, p
) # (batch_size, num_head, seq_len, 2*seq_len-1)
matrix_bd = self.rel_shift(
matrix_bd
) # (batch_size, num_head, seq_len, seq_len)
# (batch_size, num_head, seq_len, seq_len)
attn_output_weights = (matrix_ac + matrix_bd) * scaling
attn_output_weights = attn_output_weights.view(
batch_size * self.num_heads, seq_len, seq_len
)
if key_padding_mask is not None:
assert key_padding_mask.shape == (batch_size, seq_len)
attn_output_weights = attn_output_weights.view(
batch_size, self.num_heads, seq_len, seq_len
)
attn_output_weights = attn_output_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2),
float("-inf"),
)
attn_output_weights = attn_output_weights.view(
batch_size * self.num_heads, seq_len, seq_len
)
attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
attn_output_weights = nn.functional.dropout(
attn_output_weights, p=self.dropout, training=self.training
)
# (batch_size * num_head, seq_len, head_dim)
attn_output = torch.bmm(attn_output_weights, v)
assert attn_output.shape == (
batch_size * self.num_heads,
seq_len,
self.head_dim,
)
attn_output = (
attn_output.transpose(0, 1)
.contiguous()
.view(seq_len, batch_size, self.embed_dim)
)
# (seq_len, batch_size, embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output
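# --- minimal usage sketch (not part of the original file) --------------------
# Pairs RelPositionalEncoding with RelPositionMultiheadAttention the way
# TransformerEncoderLayer does; the tensor sizes below are arbitrary.
def _test_rel_position_multihead_attention():
    d_model, num_heads, seq_len, batch_size = 192, 2, 25, 3
    pos_enc = RelPositionalEncoding(d_model, dropout_rate=0.1)
    attn = RelPositionMultiheadAttention(d_model, num_heads)
    x = torch.randn(batch_size, seq_len, d_model)
    x, pos_emb = pos_enc(x)  # (B, T, C), (1, 2*T-1, C)
    x = x.permute(1, 0, 2)  # (T, B, C) as expected by the attention
    out = attn(x, pos_emb)
    assert out.shape == (seq_len, batch_size, d_model)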
class ConvolutionModule(nn.Module):
"""ConvolutionModule in Conformer model.
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py
Args:
channels (int): The number of channels of conv layers.
        kernel_size (int): Kernel size of conv layers.
bias (bool): Whether to use bias in conv layers (default=True).
"""
def __init__(
self,
channels: int,
kernel_size: int,
bias: bool = True,
) -> None:
"""Construct an ConvolutionModule object."""
super(ConvolutionModule, self).__init__()
# kernerl_size should be a odd number for 'SAME' padding
assert (kernel_size - 1) % 2 == 0
self.pointwise_conv1 = nn.Conv1d(
channels,
2 * channels,
kernel_size=1,
stride=1,
padding=0,
bias=bias,
)
padding = (kernel_size - 1) // 2
self.depthwise_conv = nn.Conv1d(
channels,
channels,
kernel_size,
stride=1,
padding=padding,
groups=channels,
bias=bias,
)
self.norm = nn.LayerNorm(channels)
self.pointwise_conv2 = nn.Conv1d(
channels,
channels,
kernel_size=1,
stride=1,
padding=0,
bias=bias,
)
self.activation = Swish()
def forward(
self,
x: Tensor,
src_key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
"""Compute convolution module.
Args:
x: Input tensor (#time, batch, channels).
src_key_padding_mask: the mask for the src keys per batch (optional).
Returns:
Tensor: Output tensor (#time, batch, channels).
"""
# exchange the temporal dimension and the feature dimension
x = x.permute(1, 2, 0) # (#batch, channels, time).
# GLU mechanism
x = self.pointwise_conv1(x) # (batch, 2*channels, time)
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
# 1D Depthwise Conv
if src_key_padding_mask is not None:
x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
x = self.depthwise_conv(x)
# x is (batch, channels, time)
x = x.permute(0, 2, 1)
x = self.norm(x)
x = x.permute(0, 2, 1)
x = self.activation(x)
x = self.pointwise_conv2(x) # (batch, channel, time)
return x.permute(2, 0, 1)
class Swish(nn.Module):
"""Construct an Swish object."""
def forward(self, x: Tensor) -> Tensor:
"""Return Swich activation function."""
return x * torch.sigmoid(x)
def _test_text_encoder():
vocabs = 500
d_model = 192
batch_size = 5
seq_len = 100
    model = TextEncoder(vocabs=vocabs, d_model=d_model)
    x, m, logs, mask = model(
x=torch.randint(low=0, high=vocabs, size=(batch_size, seq_len)),
x_lengths=torch.full((batch_size,), seq_len),
)
print(x.shape, m.shape, logs.shape, mask.shape)
if __name__ == "__main__":
_test_text_encoder()

View File

@ -1,108 +0,0 @@
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List
import g2p_en
import tacotron_cleaner.cleaners
from utils import intersperse
class Tokenizer(object):
def __init__(self, tokens: str):
"""
Args:
tokens: the file that maps tokens to ids
"""
# Parse token file
self.token2id: Dict[str, int] = {}
with open(tokens, "r", encoding="utf-8") as f:
for line in f.readlines():
info = line.rstrip().split()
if len(info) == 1:
# case of space
token = " "
id = int(info[0])
else:
token, id = info[0], int(info[1])
self.token2id[token] = id
self.blank_id = self.token2id["<blk>"]
self.oov_id = self.token2id["<unk>"]
self.vocab_size = len(self.token2id)
self.g2p = g2p_en.G2p()
def texts_to_token_ids(self, texts: List[str], intersperse_blank: bool = True):
"""
Args:
texts:
A list of transcripts.
intersperse_blank:
Whether to intersperse blanks in the token sequence.
Returns:
Return a list of token id list [utterance][token_id]
"""
token_ids_list = []
for text in texts:
# Text normalization
text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
# Convert to phonemes
tokens = self.g2p(text)
token_ids = []
for t in tokens:
if t in self.token2id:
token_ids.append(self.token2id[t])
else:
token_ids.append(self.oov_id)
if intersperse_blank:
token_ids = intersperse(token_ids, self.blank_id)
token_ids_list.append(token_ids)
return token_ids_list
def tokens_to_token_ids(
self, tokens_list: List[str], intersperse_blank: bool = True
):
"""
Args:
tokens_list:
A list of token list, each corresponding to one utterance.
intersperse_blank:
Whether to intersperse blanks in the token sequence.
Returns:
Return a list of token id list [utterance][token_id]
"""
token_ids_list = []
for tokens in tokens_list:
token_ids = []
for t in tokens:
if t in self.token2id:
token_ids.append(self.token2id[t])
else:
token_ids.append(self.oov_id)
if intersperse_blank:
token_ids = intersperse(token_ids, self.blank_id)
token_ids_list.append(token_ids)
return token_ids_list
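# --- minimal usage sketch (not part of the original file) --------------------
# "data/tokens.txt" is an illustrative path to a token-to-id mapping file that
# must contain <blk> and <unk>; g2p_en downloads its models on first use.
def _test_tokenizer():
    tokenizer = Tokenizer("data/tokens.txt")
    token_ids = tokenizer.texts_to_token_ids(["Hello world."])
    print(tokenizer.vocab_size, token_ids)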