mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-04 06:34:20 +00:00
removed redundant files
This commit is contained in:
parent
0f051f5518
commit
9931694455
@ -1,194 +0,0 @@
|
||||
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/duration_predictor.py
|
||||
|
||||
# Copyright 2021 Tomoki Hayashi
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""Stochastic duration predictor modules in VITS.
|
||||
|
||||
This code is based on https://github.com/jaywalnut310/vits.
|
||||
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from flow import (
|
||||
ConvFlow,
|
||||
DilatedDepthSeparableConv,
|
||||
ElementwiseAffineFlow,
|
||||
FlipFlow,
|
||||
LogFlow,
|
||||
)
|
||||
|
||||
|
||||
class StochasticDurationPredictor(torch.nn.Module):
|
||||
"""Stochastic duration predictor module.
|
||||
|
||||
This is a module of stochastic duration predictor described in `Conditional
|
||||
Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`_.
|
||||
|
||||
.. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
|
||||
Text-to-Speech`: https://arxiv.org/abs/2006.04558
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int = 192,
|
||||
kernel_size: int = 3,
|
||||
dropout_rate: float = 0.5,
|
||||
flows: int = 4,
|
||||
dds_conv_layers: int = 3,
|
||||
global_channels: int = -1,
|
||||
):
|
||||
"""Initialize StochasticDurationPredictor module.
|
||||
|
||||
Args:
|
||||
channels (int): Number of channels.
|
||||
kernel_size (int): Kernel size.
|
||||
dropout_rate (float): Dropout rate.
|
||||
flows (int): Number of flows.
|
||||
dds_conv_layers (int): Number of conv layers in DDS conv.
|
||||
global_channels (int): Number of global conditioning channels.
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.pre = torch.nn.Conv1d(channels, channels, 1)
|
||||
self.dds = DilatedDepthSeparableConv(
|
||||
channels,
|
||||
kernel_size,
|
||||
layers=dds_conv_layers,
|
||||
dropout_rate=dropout_rate,
|
||||
)
|
||||
self.proj = torch.nn.Conv1d(channels, channels, 1)
|
||||
|
||||
self.log_flow = LogFlow()
|
||||
self.flows = torch.nn.ModuleList()
|
||||
self.flows += [ElementwiseAffineFlow(2)]
|
||||
for i in range(flows):
|
||||
self.flows += [
|
||||
ConvFlow(
|
||||
2,
|
||||
channels,
|
||||
kernel_size,
|
||||
layers=dds_conv_layers,
|
||||
)
|
||||
]
|
||||
self.flows += [FlipFlow()]
|
||||
|
||||
self.post_pre = torch.nn.Conv1d(1, channels, 1)
|
||||
self.post_dds = DilatedDepthSeparableConv(
|
||||
channels,
|
||||
kernel_size,
|
||||
layers=dds_conv_layers,
|
||||
dropout_rate=dropout_rate,
|
||||
)
|
||||
self.post_proj = torch.nn.Conv1d(channels, channels, 1)
|
||||
self.post_flows = torch.nn.ModuleList()
|
||||
self.post_flows += [ElementwiseAffineFlow(2)]
|
||||
for i in range(flows):
|
||||
self.post_flows += [
|
||||
ConvFlow(
|
||||
2,
|
||||
channels,
|
||||
kernel_size,
|
||||
layers=dds_conv_layers,
|
||||
)
|
||||
]
|
||||
self.post_flows += [FlipFlow()]
|
||||
|
||||
if global_channels > 0:
|
||||
self.global_conv = torch.nn.Conv1d(global_channels, channels, 1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_mask: torch.Tensor,
|
||||
w: Optional[torch.Tensor] = None,
|
||||
g: Optional[torch.Tensor] = None,
|
||||
inverse: bool = False,
|
||||
noise_scale: float = 1.0,
|
||||
) -> torch.Tensor:
|
||||
"""Calculate forward propagation.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input tensor (B, channels, T_text).
|
||||
x_mask (Tensor): Mask tensor (B, 1, T_text).
|
||||
w (Optional[Tensor]): Duration tensor (B, 1, T_text).
|
||||
g (Optional[Tensor]): Global conditioning tensor (B, channels, 1)
|
||||
inverse (bool): Whether to inverse the flow.
|
||||
noise_scale (float): Noise scale value.
|
||||
|
||||
Returns:
|
||||
Tensor: If not inverse, negative log-likelihood (NLL) tensor (B,).
|
||||
If inverse, log-duration tensor (B, 1, T_text).
|
||||
|
||||
"""
|
||||
x = x.detach() # stop gradient
|
||||
x = self.pre(x)
|
||||
if g is not None:
|
||||
x = x + self.global_conv(g.detach()) # stop gradient
|
||||
x = self.dds(x, x_mask)
|
||||
x = self.proj(x) * x_mask
|
||||
|
||||
if not inverse:
|
||||
assert w is not None, "w must be provided."
|
||||
h_w = self.post_pre(w)
|
||||
h_w = self.post_dds(h_w, x_mask)
|
||||
h_w = self.post_proj(h_w) * x_mask
|
||||
e_q = (
|
||||
torch.randn(
|
||||
w.size(0),
|
||||
2,
|
||||
w.size(2),
|
||||
).to(device=x.device, dtype=x.dtype)
|
||||
* x_mask
|
||||
)
|
||||
z_q = e_q
|
||||
logdet_tot_q = 0.0
|
||||
for flow in self.post_flows:
|
||||
z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
|
||||
logdet_tot_q += logdet_q
|
||||
z_u, z1 = torch.split(z_q, [1, 1], 1)
|
||||
u = torch.sigmoid(z_u) * x_mask
|
||||
z0 = (w - u) * x_mask
|
||||
logdet_tot_q += torch.sum(
|
||||
(F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
|
||||
)
|
||||
logq = (
|
||||
torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
|
||||
- logdet_tot_q
|
||||
)
|
||||
|
||||
logdet_tot = 0
|
||||
z0, logdet = self.log_flow(z0, x_mask)
|
||||
logdet_tot += logdet
|
||||
z = torch.cat([z0, z1], 1)
|
||||
for flow in self.flows:
|
||||
z, logdet = flow(z, x_mask, g=x, inverse=inverse)
|
||||
logdet_tot = logdet_tot + logdet
|
||||
nll = (
|
||||
torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
|
||||
- logdet_tot
|
||||
)
|
||||
return nll + logq # (B,)
|
||||
else:
|
||||
flows = list(reversed(self.flows))
|
||||
flows = flows[:-2] + [flows[-1]] # remove a useless vflow
|
||||
z = (
|
||||
torch.randn(
|
||||
x.size(0),
|
||||
2,
|
||||
x.size(2),
|
||||
).to(device=x.device, dtype=x.dtype)
|
||||
* noise_scale
|
||||
)
|
||||
for flow in flows:
|
||||
z = flow(z, x_mask, g=x, inverse=inverse)
|
||||
z0, z1 = z.split(1, 1)
|
||||
logw = z0
|
||||
return logw
|
@ -1,312 +0,0 @@
|
||||
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/flow.py
|
||||
|
||||
# Copyright 2021 Tomoki Hayashi
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""Basic Flow modules used in VITS.
|
||||
|
||||
This code is based on https://github.com/jaywalnut310/vits.
|
||||
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from transform import piecewise_rational_quadratic_transform
|
||||
|
||||
|
||||
class FlipFlow(torch.nn.Module):
|
||||
"""Flip flow module."""
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, *args, inverse: bool = False, **kwargs
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""Calculate forward propagation.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input tensor (B, channels, T).
|
||||
inverse (bool): Whether to inverse the flow.
|
||||
|
||||
Returns:
|
||||
Tensor: Flipped tensor (B, channels, T).
|
||||
Tensor: Log-determinant tensor for NLL (B,) if not inverse.
|
||||
|
||||
"""
|
||||
x = torch.flip(x, [1])
|
||||
if not inverse:
|
||||
logdet = x.new_zeros(x.size(0))
|
||||
return x, logdet
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
class LogFlow(torch.nn.Module):
|
||||
"""Log flow module."""
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_mask: torch.Tensor,
|
||||
inverse: bool = False,
|
||||
eps: float = 1e-5,
|
||||
**kwargs
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""Calculate forward propagation.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input tensor (B, channels, T).
|
||||
x_mask (Tensor): Mask tensor (B, 1, T).
|
||||
inverse (bool): Whether to inverse the flow.
|
||||
eps (float): Epsilon for log.
|
||||
|
||||
Returns:
|
||||
Tensor: Output tensor (B, channels, T).
|
||||
Tensor: Log-determinant tensor for NLL (B,) if not inverse.
|
||||
|
||||
"""
|
||||
if not inverse:
|
||||
y = torch.log(torch.clamp_min(x, eps)) * x_mask
|
||||
logdet = torch.sum(-y, [1, 2])
|
||||
return y, logdet
|
||||
else:
|
||||
x = torch.exp(x) * x_mask
|
||||
return x
|
||||
|
||||
|
||||
class ElementwiseAffineFlow(torch.nn.Module):
|
||||
"""Elementwise affine flow module."""
|
||||
|
||||
def __init__(self, channels: int):
|
||||
"""Initialize ElementwiseAffineFlow module.
|
||||
|
||||
Args:
|
||||
channels (int): Number of channels.
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.register_parameter("m", torch.nn.Parameter(torch.zeros(channels, 1)))
|
||||
self.register_parameter("logs", torch.nn.Parameter(torch.zeros(channels, 1)))
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, x_mask: torch.Tensor, inverse: bool = False, **kwargs
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""Calculate forward propagation.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input tensor (B, channels, T).
|
||||
x_lengths (Tensor): Length tensor (B,).
|
||||
inverse (bool): Whether to inverse the flow.
|
||||
|
||||
Returns:
|
||||
Tensor: Output tensor (B, channels, T).
|
||||
Tensor: Log-determinant tensor for NLL (B,) if not inverse.
|
||||
|
||||
"""
|
||||
if not inverse:
|
||||
y = self.m + torch.exp(self.logs) * x
|
||||
y = y * x_mask
|
||||
logdet = torch.sum(self.logs * x_mask, [1, 2])
|
||||
return y, logdet
|
||||
else:
|
||||
x = (x - self.m) * torch.exp(-self.logs) * x_mask
|
||||
return x
|
||||
|
||||
|
||||
class Transpose(torch.nn.Module):
|
||||
"""Transpose module for torch.nn.Sequential()."""
|
||||
|
||||
def __init__(self, dim1: int, dim2: int):
|
||||
"""Initialize Transpose module."""
|
||||
super().__init__()
|
||||
self.dim1 = dim1
|
||||
self.dim2 = dim2
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Transpose."""
|
||||
return x.transpose(self.dim1, self.dim2)
|
||||
|
||||
|
||||
class DilatedDepthSeparableConv(torch.nn.Module):
|
||||
"""Dilated depth-separable conv module."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
kernel_size: int,
|
||||
layers: int,
|
||||
dropout_rate: float = 0.0,
|
||||
eps: float = 1e-5,
|
||||
):
|
||||
"""Initialize DilatedDepthSeparableConv module.
|
||||
|
||||
Args:
|
||||
channels (int): Number of channels.
|
||||
kernel_size (int): Kernel size.
|
||||
layers (int): Number of layers.
|
||||
dropout_rate (float): Dropout rate.
|
||||
eps (float): Epsilon for layer norm.
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.convs = torch.nn.ModuleList()
|
||||
for i in range(layers):
|
||||
dilation = kernel_size**i
|
||||
padding = (kernel_size * dilation - dilation) // 2
|
||||
self.convs += [
|
||||
torch.nn.Sequential(
|
||||
torch.nn.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
groups=channels,
|
||||
dilation=dilation,
|
||||
padding=padding,
|
||||
),
|
||||
Transpose(1, 2),
|
||||
torch.nn.LayerNorm(
|
||||
channels,
|
||||
eps=eps,
|
||||
elementwise_affine=True,
|
||||
),
|
||||
Transpose(1, 2),
|
||||
torch.nn.GELU(),
|
||||
torch.nn.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
1,
|
||||
),
|
||||
Transpose(1, 2),
|
||||
torch.nn.LayerNorm(
|
||||
channels,
|
||||
eps=eps,
|
||||
elementwise_affine=True,
|
||||
),
|
||||
Transpose(1, 2),
|
||||
torch.nn.GELU(),
|
||||
torch.nn.Dropout(dropout_rate),
|
||||
)
|
||||
]
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
"""Calculate forward propagation.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input tensor (B, in_channels, T).
|
||||
x_mask (Tensor): Mask tensor (B, 1, T).
|
||||
g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
|
||||
|
||||
Returns:
|
||||
Tensor: Output tensor (B, channels, T).
|
||||
|
||||
"""
|
||||
if g is not None:
|
||||
x = x + g
|
||||
for f in self.convs:
|
||||
y = f(x * x_mask)
|
||||
x = x + y
|
||||
return x * x_mask
|
||||
|
||||
|
||||
class ConvFlow(torch.nn.Module):
|
||||
"""Convolutional flow module."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
hidden_channels: int,
|
||||
kernel_size: int,
|
||||
layers: int,
|
||||
bins: int = 10,
|
||||
tail_bound: float = 5.0,
|
||||
):
|
||||
"""Initialize ConvFlow module.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
hidden_channels (int): Number of hidden channels.
|
||||
kernel_size (int): Kernel size.
|
||||
layers (int): Number of layers.
|
||||
bins (int): Number of bins.
|
||||
tail_bound (float): Tail bound value.
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
self.half_channels = in_channels // 2
|
||||
self.hidden_channels = hidden_channels
|
||||
self.bins = bins
|
||||
self.tail_bound = tail_bound
|
||||
|
||||
self.input_conv = torch.nn.Conv1d(
|
||||
self.half_channels,
|
||||
hidden_channels,
|
||||
1,
|
||||
)
|
||||
self.dds_conv = DilatedDepthSeparableConv(
|
||||
hidden_channels,
|
||||
kernel_size,
|
||||
layers,
|
||||
dropout_rate=0.0,
|
||||
)
|
||||
self.proj = torch.nn.Conv1d(
|
||||
hidden_channels,
|
||||
self.half_channels * (bins * 3 - 1),
|
||||
1,
|
||||
)
|
||||
self.proj.weight.data.zero_()
|
||||
self.proj.bias.data.zero_()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_mask: torch.Tensor,
|
||||
g: Optional[torch.Tensor] = None,
|
||||
inverse: bool = False,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""Calculate forward propagation.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input tensor (B, channels, T).
|
||||
x_mask (Tensor): Mask tensor (B,).
|
||||
g (Optional[Tensor]): Global conditioning tensor (B, channels, 1).
|
||||
inverse (bool): Whether to inverse the flow.
|
||||
|
||||
Returns:
|
||||
Tensor: Output tensor (B, channels, T).
|
||||
Tensor: Log-determinant tensor for NLL (B,) if not inverse.
|
||||
|
||||
"""
|
||||
xa, xb = x.split(x.size(1) // 2, 1)
|
||||
h = self.input_conv(xa)
|
||||
h = self.dds_conv(h, x_mask, g=g)
|
||||
h = self.proj(h) * x_mask # (B, half_channels * (bins * 3 - 1), T)
|
||||
|
||||
b, c, t = xa.shape
|
||||
# (B, half_channels, bins * 3 - 1, T) -> (B, half_channels, T, bins * 3 - 1)
|
||||
h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2)
|
||||
|
||||
# TODO(kan-bayashi): Understand this calculation
|
||||
denom = math.sqrt(self.hidden_channels)
|
||||
unnorm_widths = h[..., : self.bins] / denom
|
||||
unnorm_heights = h[..., self.bins : 2 * self.bins] / denom
|
||||
unnorm_derivatives = h[..., 2 * self.bins :]
|
||||
xb, logdet_abs = piecewise_rational_quadratic_transform(
|
||||
xb,
|
||||
unnorm_widths,
|
||||
unnorm_heights,
|
||||
unnorm_derivatives,
|
||||
inverse=inverse,
|
||||
tails="linear",
|
||||
tail_bound=self.tail_bound,
|
||||
)
|
||||
x = torch.cat([xa, xb], 1) * x_mask
|
||||
logdet = torch.sum(logdet_abs * x_mask, [1, 2])
|
||||
if not inverse:
|
||||
return x, logdet
|
||||
else:
|
||||
return x
|
@ -1,531 +0,0 @@
|
||||
# based on https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/generator.py
|
||||
|
||||
# Copyright 2021 Tomoki Hayashi
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""Generator module in VITS.
|
||||
|
||||
This code is based on https://github.com/jaywalnut310/vits.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
import math
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from icefall.utils import make_pad_mask
|
||||
|
||||
from duration_predictor import StochasticDurationPredictor
|
||||
from hifigan import HiFiGANGenerator
|
||||
from posterior_encoder import PosteriorEncoder
|
||||
from residual_coupling import ResidualAffineCouplingBlock
|
||||
from text_encoder import TextEncoder
|
||||
from utils import get_random_segments
|
||||
|
||||
|
||||
class VITSGenerator(torch.nn.Module):
|
||||
"""Generator module in VITS, `Conditional Variational Autoencoder
|
||||
with Adversarial Learning for End-to-End Text-to-Speech`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocabs: int,
|
||||
aux_channels: int = 513,
|
||||
hidden_channels: int = 192,
|
||||
spks: Optional[int] = None,
|
||||
langs: Optional[int] = None,
|
||||
spk_embed_dim: Optional[int] = None,
|
||||
global_channels: int = -1,
|
||||
segment_size: int = 32,
|
||||
text_encoder_attention_heads: int = 2,
|
||||
text_encoder_ffn_expand: int = 4,
|
||||
text_encoder_cnn_module_kernel: int = 5,
|
||||
text_encoder_blocks: int = 6,
|
||||
text_encoder_dropout_rate: float = 0.1,
|
||||
decoder_kernel_size: int = 7,
|
||||
decoder_channels: int = 512,
|
||||
decoder_upsample_scales: List[int] = [8, 8, 2, 2],
|
||||
decoder_upsample_kernel_sizes: List[int] = [16, 16, 4, 4],
|
||||
decoder_resblock_kernel_sizes: List[int] = [3, 7, 11],
|
||||
decoder_resblock_dilations: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
use_weight_norm_in_decoder: bool = True,
|
||||
posterior_encoder_kernel_size: int = 5,
|
||||
posterior_encoder_layers: int = 16,
|
||||
posterior_encoder_stacks: int = 1,
|
||||
posterior_encoder_base_dilation: int = 1,
|
||||
posterior_encoder_dropout_rate: float = 0.0,
|
||||
use_weight_norm_in_posterior_encoder: bool = True,
|
||||
flow_flows: int = 4,
|
||||
flow_kernel_size: int = 5,
|
||||
flow_base_dilation: int = 1,
|
||||
flow_layers: int = 4,
|
||||
flow_dropout_rate: float = 0.0,
|
||||
use_weight_norm_in_flow: bool = True,
|
||||
use_only_mean_in_flow: bool = True,
|
||||
stochastic_duration_predictor_kernel_size: int = 3,
|
||||
stochastic_duration_predictor_dropout_rate: float = 0.5,
|
||||
stochastic_duration_predictor_flows: int = 4,
|
||||
stochastic_duration_predictor_dds_conv_layers: int = 3,
|
||||
):
|
||||
"""Initialize VITS generator module.
|
||||
|
||||
Args:
|
||||
vocabs (int): Input vocabulary size.
|
||||
aux_channels (int): Number of acoustic feature channels.
|
||||
hidden_channels (int): Number of hidden channels.
|
||||
spks (Optional[int]): Number of speakers. If set to > 1, assume that the
|
||||
sids will be provided as the input and use sid embedding layer.
|
||||
langs (Optional[int]): Number of languages. If set to > 1, assume that the
|
||||
lids will be provided as the input and use sid embedding layer.
|
||||
spk_embed_dim (Optional[int]): Speaker embedding dimension. If set to > 0,
|
||||
assume that spembs will be provided as the input.
|
||||
global_channels (int): Number of global conditioning channels.
|
||||
segment_size (int): Segment size for decoder.
|
||||
text_encoder_attention_heads (int): Number of heads in conformer block
|
||||
of text encoder.
|
||||
text_encoder_ffn_expand (int): Expansion ratio of FFN in conformer block
|
||||
of text encoder.
|
||||
text_encoder_cnn_module_kernel (int): Convolution kernel size in text encoder.
|
||||
text_encoder_blocks (int): Number of conformer blocks in text encoder.
|
||||
text_encoder_dropout_rate (float): Dropout rate in conformer block of
|
||||
text encoder.
|
||||
decoder_kernel_size (int): Decoder kernel size.
|
||||
decoder_channels (int): Number of decoder initial channels.
|
||||
decoder_upsample_scales (List[int]): List of upsampling scales in decoder.
|
||||
decoder_upsample_kernel_sizes (List[int]): List of kernel size for
|
||||
upsampling layers in decoder.
|
||||
decoder_resblock_kernel_sizes (List[int]): List of kernel size for resblocks
|
||||
in decoder.
|
||||
decoder_resblock_dilations (List[List[int]]): List of list of dilations for
|
||||
resblocks in decoder.
|
||||
use_weight_norm_in_decoder (bool): Whether to apply weight normalization in
|
||||
decoder.
|
||||
posterior_encoder_kernel_size (int): Posterior encoder kernel size.
|
||||
posterior_encoder_layers (int): Number of layers of posterior encoder.
|
||||
posterior_encoder_stacks (int): Number of stacks of posterior encoder.
|
||||
posterior_encoder_base_dilation (int): Base dilation of posterior encoder.
|
||||
posterior_encoder_dropout_rate (float): Dropout rate for posterior encoder.
|
||||
use_weight_norm_in_posterior_encoder (bool): Whether to apply weight
|
||||
normalization in posterior encoder.
|
||||
flow_flows (int): Number of flows in flow.
|
||||
flow_kernel_size (int): Kernel size in flow.
|
||||
flow_base_dilation (int): Base dilation in flow.
|
||||
flow_layers (int): Number of layers in flow.
|
||||
flow_dropout_rate (float): Dropout rate in flow
|
||||
use_weight_norm_in_flow (bool): Whether to apply weight normalization in
|
||||
flow.
|
||||
use_only_mean_in_flow (bool): Whether to use only mean in flow.
|
||||
stochastic_duration_predictor_kernel_size (int): Kernel size in stochastic
|
||||
duration predictor.
|
||||
stochastic_duration_predictor_dropout_rate (float): Dropout rate in
|
||||
stochastic duration predictor.
|
||||
stochastic_duration_predictor_flows (int): Number of flows in stochastic
|
||||
duration predictor.
|
||||
stochastic_duration_predictor_dds_conv_layers (int): Number of DDS conv
|
||||
layers in stochastic duration predictor.
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
self.segment_size = segment_size
|
||||
self.text_encoder = TextEncoder(
|
||||
vocabs=vocabs,
|
||||
d_model=hidden_channels,
|
||||
num_heads=text_encoder_attention_heads,
|
||||
dim_feedforward=hidden_channels * text_encoder_ffn_expand,
|
||||
cnn_module_kernel=text_encoder_cnn_module_kernel,
|
||||
num_layers=text_encoder_blocks,
|
||||
dropout=text_encoder_dropout_rate,
|
||||
)
|
||||
self.decoder = HiFiGANGenerator(
|
||||
in_channels=hidden_channels,
|
||||
out_channels=1,
|
||||
channels=decoder_channels,
|
||||
global_channels=global_channels,
|
||||
kernel_size=decoder_kernel_size,
|
||||
upsample_scales=decoder_upsample_scales,
|
||||
upsample_kernel_sizes=decoder_upsample_kernel_sizes,
|
||||
resblock_kernel_sizes=decoder_resblock_kernel_sizes,
|
||||
resblock_dilations=decoder_resblock_dilations,
|
||||
use_weight_norm=use_weight_norm_in_decoder,
|
||||
)
|
||||
self.posterior_encoder = PosteriorEncoder(
|
||||
in_channels=aux_channels,
|
||||
out_channels=hidden_channels,
|
||||
hidden_channels=hidden_channels,
|
||||
kernel_size=posterior_encoder_kernel_size,
|
||||
layers=posterior_encoder_layers,
|
||||
stacks=posterior_encoder_stacks,
|
||||
base_dilation=posterior_encoder_base_dilation,
|
||||
global_channels=global_channels,
|
||||
dropout_rate=posterior_encoder_dropout_rate,
|
||||
use_weight_norm=use_weight_norm_in_posterior_encoder,
|
||||
)
|
||||
self.flow = ResidualAffineCouplingBlock(
|
||||
in_channels=hidden_channels,
|
||||
hidden_channels=hidden_channels,
|
||||
flows=flow_flows,
|
||||
kernel_size=flow_kernel_size,
|
||||
base_dilation=flow_base_dilation,
|
||||
layers=flow_layers,
|
||||
global_channels=global_channels,
|
||||
dropout_rate=flow_dropout_rate,
|
||||
use_weight_norm=use_weight_norm_in_flow,
|
||||
use_only_mean=use_only_mean_in_flow,
|
||||
)
|
||||
# TODO(kan-bayashi): Add deterministic version as an option
|
||||
self.duration_predictor = StochasticDurationPredictor(
|
||||
channels=hidden_channels,
|
||||
kernel_size=stochastic_duration_predictor_kernel_size,
|
||||
dropout_rate=stochastic_duration_predictor_dropout_rate,
|
||||
flows=stochastic_duration_predictor_flows,
|
||||
dds_conv_layers=stochastic_duration_predictor_dds_conv_layers,
|
||||
global_channels=global_channels,
|
||||
)
|
||||
|
||||
self.upsample_factor = int(np.prod(decoder_upsample_scales))
|
||||
self.spks = None
|
||||
if spks is not None and spks > 1:
|
||||
assert global_channels > 0
|
||||
self.spks = spks
|
||||
self.global_emb = torch.nn.Embedding(spks, global_channels)
|
||||
self.spk_embed_dim = None
|
||||
if spk_embed_dim is not None and spk_embed_dim > 0:
|
||||
assert global_channels > 0
|
||||
self.spk_embed_dim = spk_embed_dim
|
||||
self.spemb_proj = torch.nn.Linear(spk_embed_dim, global_channels)
|
||||
self.langs = None
|
||||
if langs is not None and langs > 1:
|
||||
assert global_channels > 0
|
||||
self.langs = langs
|
||||
self.lang_emb = torch.nn.Embedding(langs, global_channels)
|
||||
|
||||
# delayed import
|
||||
from monotonic_align import maximum_path
|
||||
|
||||
self.maximum_path = maximum_path
|
||||
|
||||
def forward(
|
||||
self,
|
||||
text: torch.Tensor,
|
||||
text_lengths: torch.Tensor,
|
||||
feats: torch.Tensor,
|
||||
feats_lengths: torch.Tensor,
|
||||
sids: Optional[torch.Tensor] = None,
|
||||
spembs: Optional[torch.Tensor] = None,
|
||||
lids: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
Tuple[
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
],
|
||||
]:
|
||||
"""Calculate forward propagation.
|
||||
|
||||
Args:
|
||||
text (Tensor): Text index tensor (B, T_text).
|
||||
text_lengths (Tensor): Text length tensor (B,).
|
||||
feats (Tensor): Feature tensor (B, aux_channels, T_feats).
|
||||
feats_lengths (Tensor): Feature length tensor (B,).
|
||||
sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
|
||||
spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
|
||||
lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
|
||||
|
||||
Returns:
|
||||
Tensor: Waveform tensor (B, 1, segment_size * upsample_factor).
|
||||
Tensor: Duration negative log-likelihood (NLL) tensor (B,).
|
||||
Tensor: Monotonic attention weight tensor (B, 1, T_feats, T_text).
|
||||
Tensor: Segments start index tensor (B,).
|
||||
Tensor: Text mask tensor (B, 1, T_text).
|
||||
Tensor: Feature mask tensor (B, 1, T_feats).
|
||||
tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
|
||||
- Tensor: Posterior encoder hidden representation (B, H, T_feats).
|
||||
- Tensor: Flow hidden representation (B, H, T_feats).
|
||||
- Tensor: Expanded text encoder projected mean (B, H, T_feats).
|
||||
- Tensor: Expanded text encoder projected scale (B, H, T_feats).
|
||||
- Tensor: Posterior encoder projected mean (B, H, T_feats).
|
||||
- Tensor: Posterior encoder projected scale (B, H, T_feats).
|
||||
|
||||
"""
|
||||
# forward text encoder
|
||||
x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths)
|
||||
|
||||
# calculate global conditioning
|
||||
g = None
|
||||
if self.spks is not None:
|
||||
# speaker one-hot vector embedding: (B, global_channels, 1)
|
||||
g = self.global_emb(sids.view(-1)).unsqueeze(-1)
|
||||
if self.spk_embed_dim is not None:
|
||||
# pretreined speaker embedding, e.g., X-vector (B, global_channels, 1)
|
||||
g_ = self.spemb_proj(F.normalize(spembs)).unsqueeze(-1)
|
||||
if g is None:
|
||||
g = g_
|
||||
else:
|
||||
g = g + g_
|
||||
if self.langs is not None:
|
||||
# language one-hot vector embedding: (B, global_channels, 1)
|
||||
g_ = self.lang_emb(lids.view(-1)).unsqueeze(-1)
|
||||
if g is None:
|
||||
g = g_
|
||||
else:
|
||||
g = g + g_
|
||||
|
||||
# forward posterior encoder
|
||||
z, m_q, logs_q, y_mask = self.posterior_encoder(feats, feats_lengths, g=g)
|
||||
|
||||
# forward flow
|
||||
z_p = self.flow(z, y_mask, g=g) # (B, H, T_feats)
|
||||
|
||||
# monotonic alignment search
|
||||
with torch.no_grad():
|
||||
# negative cross-entropy
|
||||
s_p_sq_r = torch.exp(-2 * logs_p) # (B, H, T_text)
|
||||
# (B, 1, T_text)
|
||||
neg_x_ent_1 = torch.sum(
|
||||
-0.5 * math.log(2 * math.pi) - logs_p,
|
||||
[1],
|
||||
keepdim=True,
|
||||
)
|
||||
# (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
|
||||
neg_x_ent_2 = torch.matmul(
|
||||
-0.5 * (z_p**2).transpose(1, 2),
|
||||
s_p_sq_r,
|
||||
)
|
||||
# (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
|
||||
neg_x_ent_3 = torch.matmul(
|
||||
z_p.transpose(1, 2),
|
||||
(m_p * s_p_sq_r),
|
||||
)
|
||||
# (B, 1, T_text)
|
||||
neg_x_ent_4 = torch.sum(
|
||||
-0.5 * (m_p**2) * s_p_sq_r,
|
||||
[1],
|
||||
keepdim=True,
|
||||
)
|
||||
# (B, T_feats, T_text)
|
||||
neg_x_ent = neg_x_ent_1 + neg_x_ent_2 + neg_x_ent_3 + neg_x_ent_4
|
||||
# (B, 1, T_feats, T_text)
|
||||
attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
|
||||
# monotonic attention weight: (B, 1, T_feats, T_text)
|
||||
attn = (
|
||||
self.maximum_path(
|
||||
neg_x_ent,
|
||||
attn_mask.squeeze(1),
|
||||
)
|
||||
.unsqueeze(1)
|
||||
.detach()
|
||||
)
|
||||
|
||||
# forward duration predictor
|
||||
w = attn.sum(2) # (B, 1, T_text)
|
||||
dur_nll = self.duration_predictor(x, x_mask, w=w, g=g)
|
||||
dur_nll = dur_nll / torch.sum(x_mask)
|
||||
|
||||
# expand the length to match with the feature sequence
|
||||
# (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
|
||||
m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
|
||||
# (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
|
||||
logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
# get random segments
|
||||
z_segments, z_start_idxs = get_random_segments(
|
||||
z,
|
||||
feats_lengths,
|
||||
self.segment_size,
|
||||
)
|
||||
|
||||
# forward decoder with random segments
|
||||
wav = self.decoder(z_segments, g=g)
|
||||
|
||||
return (
|
||||
wav,
|
||||
dur_nll,
|
||||
attn,
|
||||
z_start_idxs,
|
||||
x_mask,
|
||||
y_mask,
|
||||
(z, z_p, m_p, logs_p, m_q, logs_q),
|
||||
)
|
||||
|
||||
def inference(
|
||||
self,
|
||||
text: torch.Tensor,
|
||||
text_lengths: torch.Tensor,
|
||||
feats: Optional[torch.Tensor] = None,
|
||||
feats_lengths: Optional[torch.Tensor] = None,
|
||||
sids: Optional[torch.Tensor] = None,
|
||||
spembs: Optional[torch.Tensor] = None,
|
||||
lids: Optional[torch.Tensor] = None,
|
||||
dur: Optional[torch.Tensor] = None,
|
||||
noise_scale: float = 0.667,
|
||||
noise_scale_dur: float = 0.8,
|
||||
alpha: float = 1.0,
|
||||
max_len: Optional[int] = None,
|
||||
use_teacher_forcing: bool = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Run inference.
|
||||
|
||||
Args:
|
||||
text (Tensor): Input text index tensor (B, T_text,).
|
||||
text_lengths (Tensor): Text length tensor (B,).
|
||||
feats (Tensor): Feature tensor (B, aux_channels, T_feats,).
|
||||
feats_lengths (Tensor): Feature length tensor (B,).
|
||||
sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
|
||||
spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
|
||||
lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
|
||||
dur (Optional[Tensor]): Ground-truth duration (B, T_text,). If provided,
|
||||
skip the prediction of durations (i.e., teacher forcing).
|
||||
noise_scale (float): Noise scale parameter for flow.
|
||||
noise_scale_dur (float): Noise scale parameter for duration predictor.
|
||||
alpha (float): Alpha parameter to control the speed of generated speech.
|
||||
max_len (Optional[int]): Maximum length of acoustic feature sequence.
|
||||
use_teacher_forcing (bool): Whether to use teacher forcing.
|
||||
|
||||
Returns:
|
||||
Tensor: Generated waveform tensor (B, T_wav).
|
||||
Tensor: Monotonic attention weight tensor (B, T_feats, T_text).
|
||||
Tensor: Duration tensor (B, T_text).
|
||||
|
||||
"""
|
||||
# encoder
|
||||
x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths)
|
||||
x_mask = x_mask.to(x.dtype)
|
||||
g = None
|
||||
if self.spks is not None:
|
||||
# (B, global_channels, 1)
|
||||
g = self.global_emb(sids.view(-1)).unsqueeze(-1)
|
||||
if self.spk_embed_dim is not None:
|
||||
# (B, global_channels, 1)
|
||||
g_ = self.spemb_proj(F.normalize(spembs.unsqueeze(0))).unsqueeze(-1)
|
||||
if g is None:
|
||||
g = g_
|
||||
else:
|
||||
g = g + g_
|
||||
if self.langs is not None:
|
||||
# (B, global_channels, 1)
|
||||
g_ = self.lang_emb(lids.view(-1)).unsqueeze(-1)
|
||||
if g is None:
|
||||
g = g_
|
||||
else:
|
||||
g = g + g_
|
||||
|
||||
if use_teacher_forcing:
|
||||
# forward posterior encoder
|
||||
z, m_q, logs_q, y_mask = self.posterior_encoder(feats, feats_lengths, g=g)
|
||||
|
||||
# forward flow
|
||||
z_p = self.flow(z, y_mask, g=g) # (B, H, T_feats)
|
||||
|
||||
# monotonic alignment search
|
||||
s_p_sq_r = torch.exp(-2 * logs_p) # (B, H, T_text)
|
||||
# (B, 1, T_text)
|
||||
neg_x_ent_1 = torch.sum(
|
||||
-0.5 * math.log(2 * math.pi) - logs_p,
|
||||
[1],
|
||||
keepdim=True,
|
||||
)
|
||||
# (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
|
||||
neg_x_ent_2 = torch.matmul(
|
||||
-0.5 * (z_p**2).transpose(1, 2),
|
||||
s_p_sq_r,
|
||||
)
|
||||
# (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
|
||||
neg_x_ent_3 = torch.matmul(
|
||||
z_p.transpose(1, 2),
|
||||
(m_p * s_p_sq_r),
|
||||
)
|
||||
# (B, 1, T_text)
|
||||
neg_x_ent_4 = torch.sum(
|
||||
-0.5 * (m_p**2) * s_p_sq_r,
|
||||
[1],
|
||||
keepdim=True,
|
||||
)
|
||||
# (B, T_feats, T_text)
|
||||
neg_x_ent = neg_x_ent_1 + neg_x_ent_2 + neg_x_ent_3 + neg_x_ent_4
|
||||
# (B, 1, T_feats, T_text)
|
||||
attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
|
||||
# monotonic attention weight: (B, 1, T_feats, T_text)
|
||||
attn = self.maximum_path(
|
||||
neg_x_ent,
|
||||
attn_mask.squeeze(1),
|
||||
).unsqueeze(1)
|
||||
dur = attn.sum(2) # (B, 1, T_text)
|
||||
|
||||
# forward decoder with random segments
|
||||
wav = self.decoder(z * y_mask, g=g)
|
||||
else:
|
||||
# duration
|
||||
if dur is None:
|
||||
logw = self.duration_predictor(
|
||||
x,
|
||||
x_mask,
|
||||
g=g,
|
||||
inverse=True,
|
||||
noise_scale=noise_scale_dur,
|
||||
)
|
||||
w = torch.exp(logw) * x_mask * alpha
|
||||
dur = torch.ceil(w)
|
||||
y_lengths = torch.clamp_min(torch.sum(dur, [1, 2]), 1).long()
|
||||
y_mask = (~make_pad_mask(y_lengths)).unsqueeze(1).to(text.device)
|
||||
y_mask = y_mask.to(x.dtype)
|
||||
attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
|
||||
attn = self._generate_path(dur, attn_mask)
|
||||
|
||||
# expand the length to match with the feature sequence
|
||||
# (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
|
||||
m_p = torch.matmul(
|
||||
attn.squeeze(1),
|
||||
m_p.transpose(1, 2),
|
||||
).transpose(1, 2)
|
||||
# (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
|
||||
logs_p = torch.matmul(
|
||||
attn.squeeze(1),
|
||||
logs_p.transpose(1, 2),
|
||||
).transpose(1, 2)
|
||||
|
||||
# decoder
|
||||
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
|
||||
z = self.flow(z_p, y_mask, g=g, inverse=True)
|
||||
wav = self.decoder((z * y_mask)[:, :, :max_len], g=g)
|
||||
|
||||
return wav.squeeze(1), attn.squeeze(1), dur.squeeze(1)
|
||||
|
||||
def _generate_path(self, dur: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
||||
"""Generate path a.k.a. monotonic attention.
|
||||
|
||||
Args:
|
||||
dur (Tensor): Duration tensor (B, 1, T_text).
|
||||
mask (Tensor): Attention mask tensor (B, 1, T_feats, T_text).
|
||||
|
||||
Returns:
|
||||
Tensor: Path tensor (B, 1, T_feats, T_text).
|
||||
|
||||
"""
|
||||
b, _, t_y, t_x = mask.shape
|
||||
cum_dur = torch.cumsum(dur, -1)
|
||||
cum_dur_flat = cum_dur.view(b * t_x)
|
||||
path = torch.arange(t_y, dtype=dur.dtype, device=dur.device)
|
||||
path = path.unsqueeze(0) < cum_dur_flat.unsqueeze(1)
|
||||
# path = path.view(b, t_x, t_y).to(dtype=mask.dtype)
|
||||
path = path.view(b, t_x, t_y).to(dtype=torch.float)
|
||||
# path will be like (t_x = 3, t_y = 5):
|
||||
# [[[1., 1., 0., 0., 0.], [[[1., 1., 0., 0., 0.],
|
||||
# [1., 1., 1., 1., 0.], --> [0., 0., 1., 1., 0.],
|
||||
# [1., 1., 1., 1., 1.]]] [0., 0., 0., 0., 1.]]]
|
||||
path = path - F.pad(path, [0, 0, 1, 0, 0, 0])[:, :-1]
|
||||
# path = path.to(dtype=mask.dtype)
|
||||
return path.unsqueeze(1).transpose(2, 3) * mask
|
@ -1,933 +0,0 @@
|
||||
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/hifigan.py
|
||||
|
||||
# Copyright 2021 Tomoki Hayashi
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""HiFi-GAN Modules.
|
||||
|
||||
This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN.
|
||||
|
||||
"""
|
||||
|
||||
import copy
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class HiFiGANGenerator(torch.nn.Module):
|
||||
"""HiFiGAN generator module."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 80,
|
||||
out_channels: int = 1,
|
||||
channels: int = 512,
|
||||
global_channels: int = -1,
|
||||
kernel_size: int = 7,
|
||||
upsample_scales: List[int] = [8, 8, 2, 2],
|
||||
upsample_kernel_sizes: List[int] = [16, 16, 4, 4],
|
||||
resblock_kernel_sizes: List[int] = [3, 7, 11],
|
||||
resblock_dilations: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
use_additional_convs: bool = True,
|
||||
bias: bool = True,
|
||||
nonlinear_activation: str = "LeakyReLU",
|
||||
nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1},
|
||||
use_weight_norm: bool = True,
|
||||
):
|
||||
"""Initialize HiFiGANGenerator module.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
out_channels (int): Number of output channels.
|
||||
channels (int): Number of hidden representation channels.
|
||||
global_channels (int): Number of global conditioning channels.
|
||||
kernel_size (int): Kernel size of initial and final conv layer.
|
||||
upsample_scales (List[int]): List of upsampling scales.
|
||||
upsample_kernel_sizes (List[int]): List of kernel sizes for upsample layers.
|
||||
resblock_kernel_sizes (List[int]): List of kernel sizes for residual blocks.
|
||||
resblock_dilations (List[List[int]]): List of list of dilations for residual
|
||||
blocks.
|
||||
use_additional_convs (bool): Whether to use additional conv layers in
|
||||
residual blocks.
|
||||
bias (bool): Whether to add bias parameter in convolution layers.
|
||||
nonlinear_activation (str): Activation function module name.
|
||||
nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation
|
||||
function.
|
||||
use_weight_norm (bool): Whether to use weight norm. If set to true, it will
|
||||
be applied to all of the conv layers.
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
# check hyperparameters are valid
|
||||
assert kernel_size % 2 == 1, "Kernel size must be odd number."
|
||||
assert len(upsample_scales) == len(upsample_kernel_sizes)
|
||||
assert len(resblock_dilations) == len(resblock_kernel_sizes)
|
||||
|
||||
# define modules
|
||||
self.upsample_factor = int(np.prod(upsample_scales) * out_channels)
|
||||
self.num_upsamples = len(upsample_kernel_sizes)
|
||||
self.num_blocks = len(resblock_kernel_sizes)
|
||||
self.input_conv = torch.nn.Conv1d(
|
||||
in_channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
)
|
||||
self.upsamples = torch.nn.ModuleList()
|
||||
self.blocks = torch.nn.ModuleList()
|
||||
for i in range(len(upsample_kernel_sizes)):
|
||||
assert upsample_kernel_sizes[i] == 2 * upsample_scales[i]
|
||||
self.upsamples += [
|
||||
torch.nn.Sequential(
|
||||
getattr(torch.nn, nonlinear_activation)(
|
||||
**nonlinear_activation_params
|
||||
),
|
||||
torch.nn.ConvTranspose1d(
|
||||
channels // (2**i),
|
||||
channels // (2 ** (i + 1)),
|
||||
upsample_kernel_sizes[i],
|
||||
upsample_scales[i],
|
||||
padding=upsample_scales[i] // 2 + upsample_scales[i] % 2,
|
||||
output_padding=upsample_scales[i] % 2,
|
||||
),
|
||||
)
|
||||
]
|
||||
for j in range(len(resblock_kernel_sizes)):
|
||||
self.blocks += [
|
||||
ResidualBlock(
|
||||
kernel_size=resblock_kernel_sizes[j],
|
||||
channels=channels // (2 ** (i + 1)),
|
||||
dilations=resblock_dilations[j],
|
||||
bias=bias,
|
||||
use_additional_convs=use_additional_convs,
|
||||
nonlinear_activation=nonlinear_activation,
|
||||
nonlinear_activation_params=nonlinear_activation_params,
|
||||
)
|
||||
]
|
||||
self.output_conv = torch.nn.Sequential(
|
||||
# NOTE(kan-bayashi): follow official implementation but why
|
||||
# using different slope parameter here? (0.1 vs. 0.01)
|
||||
torch.nn.LeakyReLU(),
|
||||
torch.nn.Conv1d(
|
||||
channels // (2 ** (i + 1)),
|
||||
out_channels,
|
||||
kernel_size,
|
||||
1,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
),
|
||||
torch.nn.Tanh(),
|
||||
)
|
||||
if global_channels > 0:
|
||||
self.global_conv = torch.nn.Conv1d(global_channels, channels, 1)
|
||||
|
||||
# apply weight norm
|
||||
if use_weight_norm:
|
||||
self.apply_weight_norm()
|
||||
|
||||
# reset parameters
|
||||
self.reset_parameters()
|
||||
|
||||
def forward(
|
||||
self, c: torch.Tensor, g: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
"""Calculate forward propagation.
|
||||
|
||||
Args:
|
||||
c (Tensor): Input tensor (B, in_channels, T).
|
||||
g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
|
||||
|
||||
Returns:
|
||||
Tensor: Output tensor (B, out_channels, T).
|
||||
|
||||
"""
|
||||
c = self.input_conv(c)
|
||||
if g is not None:
|
||||
c = c + self.global_conv(g)
|
||||
for i in range(self.num_upsamples):
|
||||
c = self.upsamples[i](c)
|
||||
cs = 0.0 # initialize
|
||||
for j in range(self.num_blocks):
|
||||
cs += self.blocks[i * self.num_blocks + j](c)
|
||||
c = cs / self.num_blocks
|
||||
c = self.output_conv(c)
|
||||
|
||||
return c
|
||||
|
||||
def reset_parameters(self):
|
||||
"""Reset parameters.
|
||||
|
||||
This initialization follows the official implementation manner.
|
||||
https://github.com/jik876/hifi-gan/blob/master/models.py
|
||||
|
||||
"""
|
||||
|
||||
def _reset_parameters(m: torch.nn.Module):
|
||||
if isinstance(m, (torch.nn.Conv1d, torch.nn.ConvTranspose1d)):
|
||||
m.weight.data.normal_(0.0, 0.01)
|
||||
logging.debug(f"Reset parameters in {m}.")
|
||||
|
||||
self.apply(_reset_parameters)
|
||||
|
||||
def remove_weight_norm(self):
|
||||
"""Remove weight normalization module from all of the layers."""
|
||||
|
||||
def _remove_weight_norm(m: torch.nn.Module):
|
||||
try:
|
||||
logging.debug(f"Weight norm is removed from {m}.")
|
||||
torch.nn.utils.remove_weight_norm(m)
|
||||
except ValueError: # this module didn't have weight norm
|
||||
return
|
||||
|
||||
self.apply(_remove_weight_norm)
|
||||
|
||||
def apply_weight_norm(self):
|
||||
"""Apply weight normalization module from all of the layers."""
|
||||
|
||||
def _apply_weight_norm(m: torch.nn.Module):
|
||||
if isinstance(m, torch.nn.Conv1d) or isinstance(
|
||||
m, torch.nn.ConvTranspose1d
|
||||
):
|
||||
torch.nn.utils.weight_norm(m)
|
||||
logging.debug(f"Weight norm is applied to {m}.")
|
||||
|
||||
self.apply(_apply_weight_norm)
|
||||
|
||||
def inference(
|
||||
self, c: torch.Tensor, g: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
"""Perform inference.
|
||||
|
||||
Args:
|
||||
c (torch.Tensor): Input tensor (T, in_channels).
|
||||
g (Optional[Tensor]): Global conditioning tensor (global_channels, 1).
|
||||
|
||||
Returns:
|
||||
Tensor: Output tensor (T ** upsample_factor, out_channels).
|
||||
|
||||
"""
|
||||
if g is not None:
|
||||
g = g.unsqueeze(0)
|
||||
c = self.forward(c.transpose(1, 0).unsqueeze(0), g=g)
|
||||
return c.squeeze(0).transpose(1, 0)
|
||||
|
||||
|
||||
class ResidualBlock(torch.nn.Module):
|
||||
"""Residual block module in HiFiGAN."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kernel_size: int = 3,
|
||||
channels: int = 512,
|
||||
dilations: List[int] = [1, 3, 5],
|
||||
bias: bool = True,
|
||||
use_additional_convs: bool = True,
|
||||
nonlinear_activation: str = "LeakyReLU",
|
||||
nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1},
|
||||
):
|
||||
"""Initialize ResidualBlock module.
|
||||
|
||||
Args:
|
||||
kernel_size (int): Kernel size of dilation convolution layer.
|
||||
channels (int): Number of channels for convolution layer.
|
||||
dilations (List[int]): List of dilation factors.
|
||||
use_additional_convs (bool): Whether to use additional convolution layers.
|
||||
bias (bool): Whether to add bias parameter in convolution layers.
|
||||
nonlinear_activation (str): Activation function module name.
|
||||
nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation
|
||||
function.
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
self.use_additional_convs = use_additional_convs
|
||||
self.convs1 = torch.nn.ModuleList()
|
||||
if use_additional_convs:
|
||||
self.convs2 = torch.nn.ModuleList()
|
||||
assert kernel_size % 2 == 1, "Kernel size must be odd number."
|
||||
for dilation in dilations:
|
||||
self.convs1 += [
|
||||
torch.nn.Sequential(
|
||||
getattr(torch.nn, nonlinear_activation)(
|
||||
**nonlinear_activation_params
|
||||
),
|
||||
torch.nn.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation,
|
||||
bias=bias,
|
||||
padding=(kernel_size - 1) // 2 * dilation,
|
||||
),
|
||||
)
|
||||
]
|
||||
if use_additional_convs:
|
||||
self.convs2 += [
|
||||
torch.nn.Sequential(
|
||||
getattr(torch.nn, nonlinear_activation)(
|
||||
**nonlinear_activation_params
|
||||
),
|
||||
torch.nn.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
bias=bias,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Calculate forward propagation.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input tensor (B, channels, T).
|
||||
|
||||
Returns:
|
||||
Tensor: Output tensor (B, channels, T).
|
||||
|
||||
"""
|
||||
for idx in range(len(self.convs1)):
|
||||
xt = self.convs1[idx](x)
|
||||
if self.use_additional_convs:
|
||||
xt = self.convs2[idx](xt)
|
||||
x = xt + x
|
||||
return x
|
||||
|
||||
|
||||
class HiFiGANPeriodDiscriminator(torch.nn.Module):
|
||||
"""HiFiGAN period discriminator module."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 1,
|
||||
out_channels: int = 1,
|
||||
period: int = 3,
|
||||
kernel_sizes: List[int] = [5, 3],
|
||||
channels: int = 32,
|
||||
downsample_scales: List[int] = [3, 3, 3, 3, 1],
|
||||
max_downsample_channels: int = 1024,
|
||||
bias: bool = True,
|
||||
nonlinear_activation: str = "LeakyReLU",
|
||||
nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1},
|
||||
use_weight_norm: bool = True,
|
||||
use_spectral_norm: bool = False,
|
||||
):
|
||||
"""Initialize HiFiGANPeriodDiscriminator module.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
out_channels (int): Number of output channels.
|
||||
period (int): Period.
|
||||
kernel_sizes (list): Kernel sizes of initial conv layers and the final conv
|
||||
layer.
|
||||
channels (int): Number of initial channels.
|
||||
downsample_scales (List[int]): List of downsampling scales.
|
||||
max_downsample_channels (int): Number of maximum downsampling channels.
|
||||
use_additional_convs (bool): Whether to use additional conv layers in
|
||||
residual blocks.
|
||||
bias (bool): Whether to add bias parameter in convolution layers.
|
||||
nonlinear_activation (str): Activation function module name.
|
||||
nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation
|
||||
function.
|
||||
use_weight_norm (bool): Whether to use weight norm.
|
||||
If set to true, it will be applied to all of the conv layers.
|
||||
use_spectral_norm (bool): Whether to use spectral norm.
|
||||
If set to true, it will be applied to all of the conv layers.
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
assert len(kernel_sizes) == 2
|
||||
assert kernel_sizes[0] % 2 == 1, "Kernel size must be odd number."
|
||||
assert kernel_sizes[1] % 2 == 1, "Kernel size must be odd number."
|
||||
|
||||
self.period = period
|
||||
self.convs = torch.nn.ModuleList()
|
||||
in_chs = in_channels
|
||||
out_chs = channels
|
||||
for downsample_scale in downsample_scales:
|
||||
self.convs += [
|
||||
torch.nn.Sequential(
|
||||
torch.nn.Conv2d(
|
||||
in_chs,
|
||||
out_chs,
|
||||
(kernel_sizes[0], 1),
|
||||
(downsample_scale, 1),
|
||||
padding=((kernel_sizes[0] - 1) // 2, 0),
|
||||
),
|
||||
getattr(torch.nn, nonlinear_activation)(
|
||||
**nonlinear_activation_params
|
||||
),
|
||||
)
|
||||
]
|
||||
in_chs = out_chs
|
||||
# NOTE(kan-bayashi): Use downsample_scale + 1?
|
||||
out_chs = min(out_chs * 4, max_downsample_channels)
|
||||
self.output_conv = torch.nn.Conv2d(
|
||||
out_chs,
|
||||
out_channels,
|
||||
(kernel_sizes[1] - 1, 1),
|
||||
1,
|
||||
padding=((kernel_sizes[1] - 1) // 2, 0),
|
||||
)
|
||||
|
||||
if use_weight_norm and use_spectral_norm:
|
||||
raise ValueError("Either use use_weight_norm or use_spectral_norm.")
|
||||
|
||||
# apply weight norm
|
||||
if use_weight_norm:
|
||||
self.apply_weight_norm()
|
||||
|
||||
# apply spectral norm
|
||||
if use_spectral_norm:
|
||||
self.apply_spectral_norm()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Calculate forward propagation.
|
||||
|
||||
Args:
|
||||
c (Tensor): Input tensor (B, in_channels, T).
|
||||
|
||||
Returns:
|
||||
list: List of each layer's tensors.
|
||||
|
||||
"""
|
||||
# transform 1d to 2d -> (B, C, T/P, P)
|
||||
b, c, t = x.shape
|
||||
if t % self.period != 0:
|
||||
n_pad = self.period - (t % self.period)
|
||||
x = F.pad(x, (0, n_pad), "reflect")
|
||||
t += n_pad
|
||||
x = x.view(b, c, t // self.period, self.period)
|
||||
|
||||
# forward conv
|
||||
outs = []
|
||||
for layer in self.convs:
|
||||
x = layer(x)
|
||||
outs += [x]
|
||||
x = self.output_conv(x)
|
||||
x = torch.flatten(x, 1, -1)
|
||||
outs += [x]
|
||||
|
||||
return outs
|
||||
|
||||
def apply_weight_norm(self):
|
||||
"""Apply weight normalization module from all of the layers."""
|
||||
|
||||
def _apply_weight_norm(m: torch.nn.Module):
|
||||
if isinstance(m, torch.nn.Conv2d):
|
||||
torch.nn.utils.weight_norm(m)
|
||||
logging.debug(f"Weight norm is applied to {m}.")
|
||||
|
||||
self.apply(_apply_weight_norm)
|
||||
|
||||
def apply_spectral_norm(self):
|
||||
"""Apply spectral normalization module from all of the layers."""
|
||||
|
||||
def _apply_spectral_norm(m: torch.nn.Module):
|
||||
if isinstance(m, torch.nn.Conv2d):
|
||||
torch.nn.utils.spectral_norm(m)
|
||||
logging.debug(f"Spectral norm is applied to {m}.")
|
||||
|
||||
self.apply(_apply_spectral_norm)
|
||||
|
||||
|
||||
class HiFiGANMultiPeriodDiscriminator(torch.nn.Module):
|
||||
"""HiFiGAN multi-period discriminator module."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
periods: List[int] = [2, 3, 5, 7, 11],
|
||||
discriminator_params: Dict[str, Any] = {
|
||||
"in_channels": 1,
|
||||
"out_channels": 1,
|
||||
"kernel_sizes": [5, 3],
|
||||
"channels": 32,
|
||||
"downsample_scales": [3, 3, 3, 3, 1],
|
||||
"max_downsample_channels": 1024,
|
||||
"bias": True,
|
||||
"nonlinear_activation": "LeakyReLU",
|
||||
"nonlinear_activation_params": {"negative_slope": 0.1},
|
||||
"use_weight_norm": True,
|
||||
"use_spectral_norm": False,
|
||||
},
|
||||
):
|
||||
"""Initialize HiFiGANMultiPeriodDiscriminator module.
|
||||
|
||||
Args:
|
||||
periods (List[int]): List of periods.
|
||||
discriminator_params (Dict[str, Any]): Parameters for hifi-gan period
|
||||
discriminator module. The period parameter will be overwritten.
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
self.discriminators = torch.nn.ModuleList()
|
||||
for period in periods:
|
||||
params = copy.deepcopy(discriminator_params)
|
||||
params["period"] = period
|
||||
self.discriminators += [HiFiGANPeriodDiscriminator(**params)]
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Calculate forward propagation.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input noise signal (B, 1, T).
|
||||
|
||||
Returns:
|
||||
List: List of list of each discriminator outputs, which consists of each
|
||||
layer output tensors.
|
||||
|
||||
"""
|
||||
outs = []
|
||||
for f in self.discriminators:
|
||||
outs += [f(x)]
|
||||
|
||||
return outs
|
||||
|
||||
|
||||
class HiFiGANScaleDiscriminator(torch.nn.Module):
|
||||
"""HiFi-GAN scale discriminator module."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 1,
|
||||
out_channels: int = 1,
|
||||
kernel_sizes: List[int] = [15, 41, 5, 3],
|
||||
channels: int = 128,
|
||||
max_downsample_channels: int = 1024,
|
||||
max_groups: int = 16,
|
||||
bias: int = True,
|
||||
downsample_scales: List[int] = [2, 2, 4, 4, 1],
|
||||
nonlinear_activation: str = "LeakyReLU",
|
||||
nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1},
|
||||
use_weight_norm: bool = True,
|
||||
use_spectral_norm: bool = False,
|
||||
):
|
||||
"""Initilize HiFiGAN scale discriminator module.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
out_channels (int): Number of output channels.
|
||||
kernel_sizes (List[int]): List of four kernel sizes. The first will be used
|
||||
for the first conv layer, and the second is for downsampling part, and
|
||||
the remaining two are for the last two output layers.
|
||||
channels (int): Initial number of channels for conv layer.
|
||||
max_downsample_channels (int): Maximum number of channels for downsampling
|
||||
layers.
|
||||
bias (bool): Whether to add bias parameter in convolution layers.
|
||||
downsample_scales (List[int]): List of downsampling scales.
|
||||
nonlinear_activation (str): Activation function module name.
|
||||
nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation
|
||||
function.
|
||||
use_weight_norm (bool): Whether to use weight norm. If set to true, it will
|
||||
be applied to all of the conv layers.
|
||||
use_spectral_norm (bool): Whether to use spectral norm. If set to true, it
|
||||
will be applied to all of the conv layers.
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
self.layers = torch.nn.ModuleList()
|
||||
|
||||
# check kernel size is valid
|
||||
assert len(kernel_sizes) == 4
|
||||
for ks in kernel_sizes:
|
||||
assert ks % 2 == 1
|
||||
|
||||
# add first layer
|
||||
self.layers += [
|
||||
torch.nn.Sequential(
|
||||
torch.nn.Conv1d(
|
||||
in_channels,
|
||||
channels,
|
||||
# NOTE(kan-bayashi): Use always the same kernel size
|
||||
kernel_sizes[0],
|
||||
bias=bias,
|
||||
padding=(kernel_sizes[0] - 1) // 2,
|
||||
),
|
||||
getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
|
||||
)
|
||||
]
|
||||
|
||||
# add downsample layers
|
||||
in_chs = channels
|
||||
out_chs = channels
|
||||
# NOTE(kan-bayashi): Remove hard coding?
|
||||
groups = 4
|
||||
for downsample_scale in downsample_scales:
|
||||
self.layers += [
|
||||
torch.nn.Sequential(
|
||||
torch.nn.Conv1d(
|
||||
in_chs,
|
||||
out_chs,
|
||||
kernel_size=kernel_sizes[1],
|
||||
stride=downsample_scale,
|
||||
padding=(kernel_sizes[1] - 1) // 2,
|
||||
groups=groups,
|
||||
bias=bias,
|
||||
),
|
||||
getattr(torch.nn, nonlinear_activation)(
|
||||
**nonlinear_activation_params
|
||||
),
|
||||
)
|
||||
]
|
||||
in_chs = out_chs
|
||||
# NOTE(kan-bayashi): Remove hard coding?
|
||||
out_chs = min(in_chs * 2, max_downsample_channels)
|
||||
# NOTE(kan-bayashi): Remove hard coding?
|
||||
groups = min(groups * 4, max_groups)
|
||||
|
||||
# add final layers
|
||||
out_chs = min(in_chs * 2, max_downsample_channels)
|
||||
self.layers += [
|
||||
torch.nn.Sequential(
|
||||
torch.nn.Conv1d(
|
||||
in_chs,
|
||||
out_chs,
|
||||
kernel_size=kernel_sizes[2],
|
||||
stride=1,
|
||||
padding=(kernel_sizes[2] - 1) // 2,
|
||||
bias=bias,
|
||||
),
|
||||
getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
|
||||
)
|
||||
]
|
||||
self.layers += [
|
||||
torch.nn.Conv1d(
|
||||
out_chs,
|
||||
out_channels,
|
||||
kernel_size=kernel_sizes[3],
|
||||
stride=1,
|
||||
padding=(kernel_sizes[3] - 1) // 2,
|
||||
bias=bias,
|
||||
),
|
||||
]
|
||||
|
||||
if use_weight_norm and use_spectral_norm:
|
||||
raise ValueError("Either use use_weight_norm or use_spectral_norm.")
|
||||
|
||||
# apply weight norm
|
||||
self.use_weight_norm = use_weight_norm
|
||||
if use_weight_norm:
|
||||
self.apply_weight_norm()
|
||||
|
||||
# apply spectral norm
|
||||
self.use_spectral_norm = use_spectral_norm
|
||||
if use_spectral_norm:
|
||||
self.apply_spectral_norm()
|
||||
|
||||
# backward compatibility
|
||||
self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
|
||||
"""Calculate forward propagation.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input noise signal (B, 1, T).
|
||||
|
||||
Returns:
|
||||
List[Tensor]: List of output tensors of each layer.
|
||||
|
||||
"""
|
||||
outs = []
|
||||
for f in self.layers:
|
||||
x = f(x)
|
||||
outs += [x]
|
||||
|
||||
return outs
|
||||
|
||||
def apply_weight_norm(self):
|
||||
"""Apply weight normalization module from all of the layers."""
|
||||
|
||||
def _apply_weight_norm(m: torch.nn.Module):
|
||||
if isinstance(m, torch.nn.Conv1d):
|
||||
torch.nn.utils.weight_norm(m)
|
||||
logging.debug(f"Weight norm is applied to {m}.")
|
||||
|
||||
self.apply(_apply_weight_norm)
|
||||
|
||||
def apply_spectral_norm(self):
|
||||
"""Apply spectral normalization module from all of the layers."""
|
||||
|
||||
def _apply_spectral_norm(m: torch.nn.Module):
|
||||
if isinstance(m, torch.nn.Conv1d):
|
||||
torch.nn.utils.spectral_norm(m)
|
||||
logging.debug(f"Spectral norm is applied to {m}.")
|
||||
|
||||
self.apply(_apply_spectral_norm)
|
||||
|
||||
def remove_weight_norm(self):
|
||||
"""Remove weight normalization module from all of the layers."""
|
||||
|
||||
def _remove_weight_norm(m):
|
||||
try:
|
||||
logging.debug(f"Weight norm is removed from {m}.")
|
||||
torch.nn.utils.remove_weight_norm(m)
|
||||
except ValueError: # this module didn't have weight norm
|
||||
return
|
||||
|
||||
self.apply(_remove_weight_norm)
|
||||
|
||||
def remove_spectral_norm(self):
|
||||
"""Remove spectral normalization module from all of the layers."""
|
||||
|
||||
def _remove_spectral_norm(m):
|
||||
try:
|
||||
logging.debug(f"Spectral norm is removed from {m}.")
|
||||
torch.nn.utils.remove_spectral_norm(m)
|
||||
except ValueError: # this module didn't have weight norm
|
||||
return
|
||||
|
||||
self.apply(_remove_spectral_norm)
|
||||
|
||||
def _load_state_dict_pre_hook(
|
||||
self,
|
||||
state_dict,
|
||||
prefix,
|
||||
local_metadata,
|
||||
strict,
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
error_msgs,
|
||||
):
|
||||
"""Fix the compatibility of weight / spectral normalization issue.
|
||||
|
||||
Some pretrained models are trained with configs that use weight / spectral
|
||||
normalization, but actually, the norm is not applied. This causes the mismatch
|
||||
of the parameters with configs. To solve this issue, when parameter mismatch
|
||||
happens in loading pretrained model, we remove the norm from the current model.
|
||||
|
||||
See also:
|
||||
- https://github.com/espnet/espnet/pull/5240
|
||||
- https://github.com/espnet/espnet/pull/5249
|
||||
- https://github.com/kan-bayashi/ParallelWaveGAN/pull/409
|
||||
|
||||
"""
|
||||
current_module_keys = [x for x in state_dict.keys() if x.startswith(prefix)]
|
||||
if self.use_weight_norm and any(
|
||||
[k.endswith("weight") for k in current_module_keys]
|
||||
):
|
||||
logging.warning(
|
||||
"It seems weight norm is not applied in the pretrained model but the"
|
||||
" current model uses it. To keep the compatibility, we remove the norm"
|
||||
" from the current model. This may cause unexpected behavior due to the"
|
||||
" parameter mismatch in finetuning. To avoid this issue, please change"
|
||||
" the following parameters in config to false:\n"
|
||||
" - discriminator_params.follow_official_norm\n"
|
||||
" - discriminator_params.scale_discriminator_params.use_weight_norm\n"
|
||||
" - discriminator_params.scale_discriminator_params.use_spectral_norm\n"
|
||||
"\n"
|
||||
"See also:\n"
|
||||
" - https://github.com/espnet/espnet/pull/5240\n"
|
||||
" - https://github.com/espnet/espnet/pull/5249"
|
||||
)
|
||||
self.remove_weight_norm()
|
||||
self.use_weight_norm = False
|
||||
for k in current_module_keys:
|
||||
if k.endswith("weight_g") or k.endswith("weight_v"):
|
||||
del state_dict[k]
|
||||
|
||||
if self.use_spectral_norm and any(
|
||||
[k.endswith("weight") for k in current_module_keys]
|
||||
):
|
||||
logging.warning(
|
||||
"It seems spectral norm is not applied in the pretrained model but the"
|
||||
" current model uses it. To keep the compatibility, we remove the norm"
|
||||
" from the current model. This may cause unexpected behavior due to the"
|
||||
" parameter mismatch in finetuning. To avoid this issue, please change"
|
||||
" the following parameters in config to false:\n"
|
||||
" - discriminator_params.follow_official_norm\n"
|
||||
" - discriminator_params.scale_discriminator_params.use_weight_norm\n"
|
||||
" - discriminator_params.scale_discriminator_params.use_spectral_norm\n"
|
||||
"\n"
|
||||
"See also:\n"
|
||||
" - https://github.com/espnet/espnet/pull/5240\n"
|
||||
" - https://github.com/espnet/espnet/pull/5249"
|
||||
)
|
||||
self.remove_spectral_norm()
|
||||
self.use_spectral_norm = False
|
||||
for k in current_module_keys:
|
||||
if (
|
||||
k.endswith("weight_u")
|
||||
or k.endswith("weight_v")
|
||||
or k.endswith("weight_orig")
|
||||
):
|
||||
del state_dict[k]
|
||||
|
||||
|
||||
class HiFiGANMultiScaleDiscriminator(torch.nn.Module):
|
||||
"""HiFi-GAN multi-scale discriminator module."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scales: int = 3,
|
||||
downsample_pooling: str = "AvgPool1d",
|
||||
# follow the official implementation setting
|
||||
downsample_pooling_params: Dict[str, Any] = {
|
||||
"kernel_size": 4,
|
||||
"stride": 2,
|
||||
"padding": 2,
|
||||
},
|
||||
discriminator_params: Dict[str, Any] = {
|
||||
"in_channels": 1,
|
||||
"out_channels": 1,
|
||||
"kernel_sizes": [15, 41, 5, 3],
|
||||
"channels": 128,
|
||||
"max_downsample_channels": 1024,
|
||||
"max_groups": 16,
|
||||
"bias": True,
|
||||
"downsample_scales": [2, 2, 4, 4, 1],
|
||||
"nonlinear_activation": "LeakyReLU",
|
||||
"nonlinear_activation_params": {"negative_slope": 0.1},
|
||||
},
|
||||
follow_official_norm: bool = False,
|
||||
):
|
||||
"""Initilize HiFiGAN multi-scale discriminator module.
|
||||
|
||||
Args:
|
||||
scales (int): Number of multi-scales.
|
||||
downsample_pooling (str): Pooling module name for downsampling of the
|
||||
inputs.
|
||||
downsample_pooling_params (Dict[str, Any]): Parameters for the above pooling
|
||||
module.
|
||||
discriminator_params (Dict[str, Any]): Parameters for hifi-gan scale
|
||||
discriminator module.
|
||||
follow_official_norm (bool): Whether to follow the norm setting of the
|
||||
official implementaion. The first discriminator uses spectral norm
|
||||
and the other discriminators use weight norm.
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
self.discriminators = torch.nn.ModuleList()
|
||||
|
||||
# add discriminators
|
||||
for i in range(scales):
|
||||
params = copy.deepcopy(discriminator_params)
|
||||
if follow_official_norm:
|
||||
if i == 0:
|
||||
params["use_weight_norm"] = False
|
||||
params["use_spectral_norm"] = True
|
||||
else:
|
||||
params["use_weight_norm"] = True
|
||||
params["use_spectral_norm"] = False
|
||||
self.discriminators += [HiFiGANScaleDiscriminator(**params)]
|
||||
self.pooling = None
|
||||
if scales > 1:
|
||||
self.pooling = getattr(torch.nn, downsample_pooling)(
|
||||
**downsample_pooling_params
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> List[List[torch.Tensor]]:
|
||||
"""Calculate forward propagation.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input noise signal (B, 1, T).
|
||||
|
||||
Returns:
|
||||
List[List[torch.Tensor]]: List of list of each discriminator outputs,
|
||||
which consists of eachlayer output tensors.
|
||||
|
||||
"""
|
||||
outs = []
|
||||
for f in self.discriminators:
|
||||
outs += [f(x)]
|
||||
if self.pooling is not None:
|
||||
x = self.pooling(x)
|
||||
|
||||
return outs
|
||||
|
||||
|
||||
class HiFiGANMultiScaleMultiPeriodDiscriminator(torch.nn.Module):
|
||||
"""HiFi-GAN multi-scale + multi-period discriminator module."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
# Multi-scale discriminator related
|
||||
scales: int = 3,
|
||||
scale_downsample_pooling: str = "AvgPool1d",
|
||||
scale_downsample_pooling_params: Dict[str, Any] = {
|
||||
"kernel_size": 4,
|
||||
"stride": 2,
|
||||
"padding": 2,
|
||||
},
|
||||
scale_discriminator_params: Dict[str, Any] = {
|
||||
"in_channels": 1,
|
||||
"out_channels": 1,
|
||||
"kernel_sizes": [15, 41, 5, 3],
|
||||
"channels": 128,
|
||||
"max_downsample_channels": 1024,
|
||||
"max_groups": 16,
|
||||
"bias": True,
|
||||
"downsample_scales": [2, 2, 4, 4, 1],
|
||||
"nonlinear_activation": "LeakyReLU",
|
||||
"nonlinear_activation_params": {"negative_slope": 0.1},
|
||||
},
|
||||
follow_official_norm: bool = True,
|
||||
# Multi-period discriminator related
|
||||
periods: List[int] = [2, 3, 5, 7, 11],
|
||||
period_discriminator_params: Dict[str, Any] = {
|
||||
"in_channels": 1,
|
||||
"out_channels": 1,
|
||||
"kernel_sizes": [5, 3],
|
||||
"channels": 32,
|
||||
"downsample_scales": [3, 3, 3, 3, 1],
|
||||
"max_downsample_channels": 1024,
|
||||
"bias": True,
|
||||
"nonlinear_activation": "LeakyReLU",
|
||||
"nonlinear_activation_params": {"negative_slope": 0.1},
|
||||
"use_weight_norm": True,
|
||||
"use_spectral_norm": False,
|
||||
},
|
||||
):
|
||||
"""Initilize HiFiGAN multi-scale + multi-period discriminator module.
|
||||
|
||||
Args:
|
||||
scales (int): Number of multi-scales.
|
||||
scale_downsample_pooling (str): Pooling module name for downsampling of the
|
||||
inputs.
|
||||
scale_downsample_pooling_params (dict): Parameters for the above pooling
|
||||
module.
|
||||
scale_discriminator_params (dict): Parameters for hifi-gan scale
|
||||
discriminator module.
|
||||
follow_official_norm (bool): Whether to follow the norm setting of the
|
||||
official implementaion. The first discriminator uses spectral norm and
|
||||
the other discriminators use weight norm.
|
||||
periods (list): List of periods.
|
||||
period_discriminator_params (dict): Parameters for hifi-gan period
|
||||
discriminator module. The period parameter will be overwritten.
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
self.msd = HiFiGANMultiScaleDiscriminator(
|
||||
scales=scales,
|
||||
downsample_pooling=scale_downsample_pooling,
|
||||
downsample_pooling_params=scale_downsample_pooling_params,
|
||||
discriminator_params=scale_discriminator_params,
|
||||
follow_official_norm=follow_official_norm,
|
||||
)
|
||||
self.mpd = HiFiGANMultiPeriodDiscriminator(
|
||||
periods=periods,
|
||||
discriminator_params=period_discriminator_params,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> List[List[torch.Tensor]]:
|
||||
"""Calculate forward propagation.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input noise signal (B, 1, T).
|
||||
|
||||
Returns:
|
||||
List[List[Tensor]]: List of list of each discriminator outputs,
|
||||
which consists of each layer output tensors. Multi scale and
|
||||
multi period ones are concatenated.
|
||||
|
||||
"""
|
||||
msd_outs = self.msd(x)
|
||||
mpd_outs = self.mpd(x)
|
||||
return msd_outs + mpd_outs
|
@ -1,336 +0,0 @@
|
||||
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/loss.py
|
||||
|
||||
# Copyright 2021 Tomoki Hayashi
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""HiFiGAN-related loss modules.
|
||||
|
||||
This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN.
|
||||
|
||||
"""
|
||||
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributions as D
|
||||
import torch.nn.functional as F
|
||||
|
||||
from lhotse.features.kaldi import Wav2LogFilterBank
|
||||
|
||||
|
||||
class GeneratorAdversarialLoss(torch.nn.Module):
|
||||
"""Generator adversarial loss module."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
average_by_discriminators: bool = True,
|
||||
loss_type: str = "mse",
|
||||
):
|
||||
"""Initialize GeneratorAversarialLoss module.
|
||||
|
||||
Args:
|
||||
average_by_discriminators (bool): Whether to average the loss by
|
||||
the number of discriminators.
|
||||
loss_type (str): Loss type, "mse" or "hinge".
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
self.average_by_discriminators = average_by_discriminators
|
||||
assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
|
||||
if loss_type == "mse":
|
||||
self.criterion = self._mse_loss
|
||||
else:
|
||||
self.criterion = self._hinge_loss
|
||||
|
||||
def forward(
|
||||
self,
|
||||
outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
"""Calcualate generator adversarial loss.
|
||||
|
||||
Args:
|
||||
outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
|
||||
outputs, list of discriminator outputs, or list of list of discriminator
|
||||
outputs..
|
||||
|
||||
Returns:
|
||||
Tensor: Generator adversarial loss value.
|
||||
|
||||
"""
|
||||
if isinstance(outputs, (tuple, list)):
|
||||
adv_loss = 0.0
|
||||
for i, outputs_ in enumerate(outputs):
|
||||
if isinstance(outputs_, (tuple, list)):
|
||||
# NOTE(kan-bayashi): case including feature maps
|
||||
outputs_ = outputs_[-1]
|
||||
adv_loss += self.criterion(outputs_)
|
||||
if self.average_by_discriminators:
|
||||
adv_loss /= i + 1
|
||||
else:
|
||||
adv_loss = self.criterion(outputs)
|
||||
|
||||
return adv_loss
|
||||
|
||||
def _mse_loss(self, x):
|
||||
return F.mse_loss(x, x.new_ones(x.size()))
|
||||
|
||||
def _hinge_loss(self, x):
|
||||
return -x.mean()
|
||||
|
||||
|
||||
class DiscriminatorAdversarialLoss(torch.nn.Module):
|
||||
"""Discriminator adversarial loss module."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
average_by_discriminators: bool = True,
|
||||
loss_type: str = "mse",
|
||||
):
|
||||
"""Initialize DiscriminatorAversarialLoss module.
|
||||
|
||||
Args:
|
||||
average_by_discriminators (bool): Whether to average the loss by
|
||||
the number of discriminators.
|
||||
loss_type (str): Loss type, "mse" or "hinge".
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
self.average_by_discriminators = average_by_discriminators
|
||||
assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
|
||||
if loss_type == "mse":
|
||||
self.fake_criterion = self._mse_fake_loss
|
||||
self.real_criterion = self._mse_real_loss
|
||||
else:
|
||||
self.fake_criterion = self._hinge_fake_loss
|
||||
self.real_criterion = self._hinge_real_loss
|
||||
|
||||
def forward(
|
||||
self,
|
||||
outputs_hat: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
|
||||
outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Calcualate discriminator adversarial loss.
|
||||
|
||||
Args:
|
||||
outputs_hat (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
|
||||
outputs, list of discriminator outputs, or list of list of discriminator
|
||||
outputs calculated from generator.
|
||||
outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
|
||||
outputs, list of discriminator outputs, or list of list of discriminator
|
||||
outputs calculated from groundtruth.
|
||||
|
||||
Returns:
|
||||
Tensor: Discriminator real loss value.
|
||||
Tensor: Discriminator fake loss value.
|
||||
|
||||
"""
|
||||
if isinstance(outputs, (tuple, list)):
|
||||
real_loss = 0.0
|
||||
fake_loss = 0.0
|
||||
for i, (outputs_hat_, outputs_) in enumerate(zip(outputs_hat, outputs)):
|
||||
if isinstance(outputs_hat_, (tuple, list)):
|
||||
# NOTE(kan-bayashi): case including feature maps
|
||||
outputs_hat_ = outputs_hat_[-1]
|
||||
outputs_ = outputs_[-1]
|
||||
real_loss += self.real_criterion(outputs_)
|
||||
fake_loss += self.fake_criterion(outputs_hat_)
|
||||
if self.average_by_discriminators:
|
||||
fake_loss /= i + 1
|
||||
real_loss /= i + 1
|
||||
else:
|
||||
real_loss = self.real_criterion(outputs)
|
||||
fake_loss = self.fake_criterion(outputs_hat)
|
||||
|
||||
return real_loss, fake_loss
|
||||
|
||||
def _mse_real_loss(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return F.mse_loss(x, x.new_ones(x.size()))
|
||||
|
||||
def _mse_fake_loss(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return F.mse_loss(x, x.new_zeros(x.size()))
|
||||
|
||||
def _hinge_real_loss(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return -torch.mean(torch.min(x - 1, x.new_zeros(x.size())))
|
||||
|
||||
def _hinge_fake_loss(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return -torch.mean(torch.min(-x - 1, x.new_zeros(x.size())))
|
||||
|
||||
|
||||
class FeatureMatchLoss(torch.nn.Module):
|
||||
"""Feature matching loss module."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
average_by_layers: bool = True,
|
||||
average_by_discriminators: bool = True,
|
||||
include_final_outputs: bool = False,
|
||||
):
|
||||
"""Initialize FeatureMatchLoss module.
|
||||
|
||||
Args:
|
||||
average_by_layers (bool): Whether to average the loss by the number
|
||||
of layers.
|
||||
average_by_discriminators (bool): Whether to average the loss by
|
||||
the number of discriminators.
|
||||
include_final_outputs (bool): Whether to include the final output of
|
||||
each discriminator for loss calculation.
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
self.average_by_layers = average_by_layers
|
||||
self.average_by_discriminators = average_by_discriminators
|
||||
self.include_final_outputs = include_final_outputs
|
||||
|
||||
def forward(
|
||||
self,
|
||||
feats_hat: Union[List[List[torch.Tensor]], List[torch.Tensor]],
|
||||
feats: Union[List[List[torch.Tensor]], List[torch.Tensor]],
|
||||
) -> torch.Tensor:
|
||||
"""Calculate feature matching loss.
|
||||
|
||||
Args:
|
||||
feats_hat (Union[List[List[Tensor]], List[Tensor]]): List of list of
|
||||
discriminator outputs or list of discriminator outputs calcuated
|
||||
from generator's outputs.
|
||||
feats (Union[List[List[Tensor]], List[Tensor]]): List of list of
|
||||
discriminator outputs or list of discriminator outputs calcuated
|
||||
from groundtruth..
|
||||
|
||||
Returns:
|
||||
Tensor: Feature matching loss value.
|
||||
|
||||
"""
|
||||
feat_match_loss = 0.0
|
||||
for i, (feats_hat_, feats_) in enumerate(zip(feats_hat, feats)):
|
||||
feat_match_loss_ = 0.0
|
||||
if not self.include_final_outputs:
|
||||
feats_hat_ = feats_hat_[:-1]
|
||||
feats_ = feats_[:-1]
|
||||
for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)):
|
||||
feat_match_loss_ += F.l1_loss(feat_hat_, feat_.detach())
|
||||
if self.average_by_layers:
|
||||
feat_match_loss_ /= j + 1
|
||||
feat_match_loss += feat_match_loss_
|
||||
if self.average_by_discriminators:
|
||||
feat_match_loss /= i + 1
|
||||
|
||||
return feat_match_loss
|
||||
|
||||
|
||||
class MelSpectrogramLoss(torch.nn.Module):
|
||||
"""Mel-spectrogram loss."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sampling_rate: int = 22050,
|
||||
frame_length: int = 1024, # in samples
|
||||
frame_shift: int = 256, # in samples
|
||||
n_mels: int = 80,
|
||||
use_fft_mag: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.wav_to_mel = Wav2LogFilterBank(
|
||||
sampling_rate=sampling_rate,
|
||||
frame_length=frame_length / sampling_rate, # in second
|
||||
frame_shift=frame_shift / sampling_rate, # in second
|
||||
use_fft_mag=use_fft_mag,
|
||||
num_filters=n_mels,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
y_hat: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
return_mel: bool = False,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]:
|
||||
"""Calculate Mel-spectrogram loss.
|
||||
|
||||
Args:
|
||||
y_hat (Tensor): Generated waveform tensor (B, 1, T).
|
||||
y (Tensor): Groundtruth waveform tensor (B, 1, T).
|
||||
spec (Optional[Tensor]): Groundtruth linear amplitude spectrum tensor
|
||||
(B, T, n_fft // 2 + 1). if provided, use it instead of groundtruth
|
||||
waveform.
|
||||
|
||||
Returns:
|
||||
Tensor: Mel-spectrogram loss value.
|
||||
|
||||
"""
|
||||
mel_hat = self.wav_to_mel(y_hat.squeeze(1))
|
||||
mel = self.wav_to_mel(y.squeeze(1))
|
||||
mel_loss = F.l1_loss(mel_hat, mel)
|
||||
|
||||
if return_mel:
|
||||
return mel_loss, (mel_hat, mel)
|
||||
|
||||
return mel_loss
|
||||
|
||||
|
||||
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/loss.py
|
||||
|
||||
"""VITS-related loss modules.
|
||||
|
||||
This code is based on https://github.com/jaywalnut310/vits.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class KLDivergenceLoss(torch.nn.Module):
|
||||
"""KL divergence loss."""
|
||||
|
||||
def forward(
|
||||
self,
|
||||
z_p: torch.Tensor,
|
||||
logs_q: torch.Tensor,
|
||||
m_p: torch.Tensor,
|
||||
logs_p: torch.Tensor,
|
||||
z_mask: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Calculate KL divergence loss.
|
||||
|
||||
Args:
|
||||
z_p (Tensor): Flow hidden representation (B, H, T_feats).
|
||||
logs_q (Tensor): Posterior encoder projected scale (B, H, T_feats).
|
||||
m_p (Tensor): Expanded text encoder projected mean (B, H, T_feats).
|
||||
logs_p (Tensor): Expanded text encoder projected scale (B, H, T_feats).
|
||||
z_mask (Tensor): Mask tensor (B, 1, T_feats).
|
||||
|
||||
Returns:
|
||||
Tensor: KL divergence loss.
|
||||
|
||||
"""
|
||||
z_p = z_p.float()
|
||||
logs_q = logs_q.float()
|
||||
m_p = m_p.float()
|
||||
logs_p = logs_p.float()
|
||||
z_mask = z_mask.float()
|
||||
kl = logs_p - logs_q - 0.5
|
||||
kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
|
||||
kl = torch.sum(kl * z_mask)
|
||||
loss = kl / torch.sum(z_mask)
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
class KLDivergenceLossWithoutFlow(torch.nn.Module):
|
||||
"""KL divergence loss without flow."""
|
||||
|
||||
def forward(
|
||||
self,
|
||||
m_q: torch.Tensor,
|
||||
logs_q: torch.Tensor,
|
||||
m_p: torch.Tensor,
|
||||
logs_p: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Calculate KL divergence loss without flow.
|
||||
|
||||
Args:
|
||||
m_q (Tensor): Posterior encoder projected mean (B, H, T_feats).
|
||||
logs_q (Tensor): Posterior encoder projected scale (B, H, T_feats).
|
||||
m_p (Tensor): Expanded text encoder projected mean (B, H, T_feats).
|
||||
logs_p (Tensor): Expanded text encoder projected scale (B, H, T_feats).
|
||||
"""
|
||||
posterior_norm = D.Normal(m_q, torch.exp(logs_q))
|
||||
prior_norm = D.Normal(m_p, torch.exp(logs_p))
|
||||
loss = D.kl_divergence(posterior_norm, prior_norm).mean()
|
||||
return loss
|
@ -1,117 +0,0 @@
|
||||
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/posterior_encoder.py
|
||||
|
||||
# Copyright 2021 Tomoki Hayashi
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""Posterior encoder module in VITS.
|
||||
|
||||
This code is based on https://github.com/jaywalnut310/vits.
|
||||
|
||||
"""
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from icefall.utils import make_pad_mask
|
||||
from wavenet import WaveNet, Conv1d
|
||||
|
||||
|
||||
class PosteriorEncoder(torch.nn.Module):
|
||||
"""Posterior encoder module in VITS.
|
||||
|
||||
This is a module of posterior encoder described in `Conditional Variational
|
||||
Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`_.
|
||||
|
||||
.. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
|
||||
Text-to-Speech`: https://arxiv.org/abs/2006.04558
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 513,
|
||||
out_channels: int = 192,
|
||||
hidden_channels: int = 192,
|
||||
kernel_size: int = 5,
|
||||
layers: int = 16,
|
||||
stacks: int = 1,
|
||||
base_dilation: int = 1,
|
||||
global_channels: int = -1,
|
||||
dropout_rate: float = 0.0,
|
||||
bias: bool = True,
|
||||
use_weight_norm: bool = True,
|
||||
):
|
||||
"""Initilialize PosteriorEncoder module.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
out_channels (int): Number of output channels.
|
||||
hidden_channels (int): Number of hidden channels.
|
||||
kernel_size (int): Kernel size in WaveNet.
|
||||
layers (int): Number of layers of WaveNet.
|
||||
stacks (int): Number of repeat stacking of WaveNet.
|
||||
base_dilation (int): Base dilation factor.
|
||||
global_channels (int): Number of global conditioning channels.
|
||||
dropout_rate (float): Dropout rate.
|
||||
bias (bool): Whether to use bias parameters in conv.
|
||||
use_weight_norm (bool): Whether to apply weight norm.
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
# define modules
|
||||
self.input_conv = Conv1d(in_channels, hidden_channels, 1)
|
||||
self.encoder = WaveNet(
|
||||
in_channels=-1,
|
||||
out_channels=-1,
|
||||
kernel_size=kernel_size,
|
||||
layers=layers,
|
||||
stacks=stacks,
|
||||
base_dilation=base_dilation,
|
||||
residual_channels=hidden_channels,
|
||||
aux_channels=-1,
|
||||
gate_channels=hidden_channels * 2,
|
||||
skip_channels=hidden_channels,
|
||||
global_channels=global_channels,
|
||||
dropout_rate=dropout_rate,
|
||||
bias=bias,
|
||||
use_weight_norm=use_weight_norm,
|
||||
use_first_conv=False,
|
||||
use_last_conv=False,
|
||||
scale_residual=False,
|
||||
scale_skip_connect=True,
|
||||
)
|
||||
self.proj = Conv1d(hidden_channels, out_channels * 2, 1)
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Calculate forward propagation.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input tensor (B, in_channels, T_feats).
|
||||
x_lengths (Tensor): Length tensor (B,).
|
||||
g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
|
||||
|
||||
Returns:
|
||||
Tensor: Encoded hidden representation tensor (B, out_channels, T_feats).
|
||||
Tensor: Projected mean tensor (B, out_channels, T_feats).
|
||||
Tensor: Projected scale tensor (B, out_channels, T_feats).
|
||||
Tensor: Mask tensor for input tensor (B, 1, T_feats).
|
||||
|
||||
"""
|
||||
x_mask = (
|
||||
(~make_pad_mask(x_lengths))
|
||||
.unsqueeze(1)
|
||||
.to(
|
||||
dtype=x.dtype,
|
||||
device=x.device,
|
||||
)
|
||||
)
|
||||
x = self.input_conv(x) * x_mask
|
||||
x = self.encoder(x, x_mask, g=g)
|
||||
stats = self.proj(x) * x_mask
|
||||
m, logs = stats.split(stats.size(1) // 2, dim=1)
|
||||
z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
|
||||
|
||||
return z, m, logs, x_mask
|
@ -1,229 +0,0 @@
|
||||
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/residual_coupling.py
|
||||
|
||||
# Copyright 2021 Tomoki Hayashi
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""Residual affine coupling modules in VITS.
|
||||
|
||||
This code is based on https://github.com/jaywalnut310/vits.
|
||||
|
||||
"""
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from flow import FlipFlow
|
||||
from wavenet import WaveNet
|
||||
|
||||
|
||||
class ResidualAffineCouplingBlock(torch.nn.Module):
|
||||
"""Residual affine coupling block module.
|
||||
|
||||
This is a module of residual affine coupling block, which used as "Flow" in
|
||||
`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
|
||||
Text-to-Speech`_.
|
||||
|
||||
.. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
|
||||
Text-to-Speech`: https://arxiv.org/abs/2006.04558
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 192,
|
||||
hidden_channels: int = 192,
|
||||
flows: int = 4,
|
||||
kernel_size: int = 5,
|
||||
base_dilation: int = 1,
|
||||
layers: int = 4,
|
||||
global_channels: int = -1,
|
||||
dropout_rate: float = 0.0,
|
||||
use_weight_norm: bool = True,
|
||||
bias: bool = True,
|
||||
use_only_mean: bool = True,
|
||||
):
|
||||
"""Initilize ResidualAffineCouplingBlock module.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
hidden_channels (int): Number of hidden channels.
|
||||
flows (int): Number of flows.
|
||||
kernel_size (int): Kernel size for WaveNet.
|
||||
base_dilation (int): Base dilation factor for WaveNet.
|
||||
layers (int): Number of layers of WaveNet.
|
||||
stacks (int): Number of stacks of WaveNet.
|
||||
global_channels (int): Number of global channels.
|
||||
dropout_rate (float): Dropout rate.
|
||||
use_weight_norm (bool): Whether to use weight normalization in WaveNet.
|
||||
bias (bool): Whether to use bias paramters in WaveNet.
|
||||
use_only_mean (bool): Whether to estimate only mean.
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.flows = torch.nn.ModuleList()
|
||||
for i in range(flows):
|
||||
self.flows += [
|
||||
ResidualAffineCouplingLayer(
|
||||
in_channels=in_channels,
|
||||
hidden_channels=hidden_channels,
|
||||
kernel_size=kernel_size,
|
||||
base_dilation=base_dilation,
|
||||
layers=layers,
|
||||
stacks=1,
|
||||
global_channels=global_channels,
|
||||
dropout_rate=dropout_rate,
|
||||
use_weight_norm=use_weight_norm,
|
||||
bias=bias,
|
||||
use_only_mean=use_only_mean,
|
||||
)
|
||||
]
|
||||
self.flows += [FlipFlow()]
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_mask: torch.Tensor,
|
||||
g: Optional[torch.Tensor] = None,
|
||||
inverse: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""Calculate forward propagation.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input tensor (B, in_channels, T).
|
||||
x_lengths (Tensor): Length tensor (B,).
|
||||
g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
|
||||
inverse (bool): Whether to inverse the flow.
|
||||
|
||||
Returns:
|
||||
Tensor: Output tensor (B, in_channels, T).
|
||||
|
||||
"""
|
||||
if not inverse:
|
||||
for flow in self.flows:
|
||||
x, _ = flow(x, x_mask, g=g, inverse=inverse)
|
||||
else:
|
||||
for flow in reversed(self.flows):
|
||||
x = flow(x, x_mask, g=g, inverse=inverse)
|
||||
return x
|
||||
|
||||
|
||||
class ResidualAffineCouplingLayer(torch.nn.Module):
|
||||
"""Residual affine coupling layer."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 192,
|
||||
hidden_channels: int = 192,
|
||||
kernel_size: int = 5,
|
||||
base_dilation: int = 1,
|
||||
layers: int = 5,
|
||||
stacks: int = 1,
|
||||
global_channels: int = -1,
|
||||
dropout_rate: float = 0.0,
|
||||
use_weight_norm: bool = True,
|
||||
bias: bool = True,
|
||||
use_only_mean: bool = True,
|
||||
):
|
||||
"""Initialzie ResidualAffineCouplingLayer module.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
hidden_channels (int): Number of hidden channels.
|
||||
kernel_size (int): Kernel size for WaveNet.
|
||||
base_dilation (int): Base dilation factor for WaveNet.
|
||||
layers (int): Number of layers of WaveNet.
|
||||
stacks (int): Number of stacks of WaveNet.
|
||||
global_channels (int): Number of global channels.
|
||||
dropout_rate (float): Dropout rate.
|
||||
use_weight_norm (bool): Whether to use weight normalization in WaveNet.
|
||||
bias (bool): Whether to use bias paramters in WaveNet.
|
||||
use_only_mean (bool): Whether to estimate only mean.
|
||||
|
||||
"""
|
||||
assert in_channels % 2 == 0, "in_channels should be divisible by 2"
|
||||
super().__init__()
|
||||
self.half_channels = in_channels // 2
|
||||
self.use_only_mean = use_only_mean
|
||||
|
||||
# define modules
|
||||
self.input_conv = torch.nn.Conv1d(
|
||||
self.half_channels,
|
||||
hidden_channels,
|
||||
1,
|
||||
)
|
||||
self.encoder = WaveNet(
|
||||
in_channels=-1,
|
||||
out_channels=-1,
|
||||
kernel_size=kernel_size,
|
||||
layers=layers,
|
||||
stacks=stacks,
|
||||
base_dilation=base_dilation,
|
||||
residual_channels=hidden_channels,
|
||||
aux_channels=-1,
|
||||
gate_channels=hidden_channels * 2,
|
||||
skip_channels=hidden_channels,
|
||||
global_channels=global_channels,
|
||||
dropout_rate=dropout_rate,
|
||||
bias=bias,
|
||||
use_weight_norm=use_weight_norm,
|
||||
use_first_conv=False,
|
||||
use_last_conv=False,
|
||||
scale_residual=False,
|
||||
scale_skip_connect=True,
|
||||
)
|
||||
if use_only_mean:
|
||||
self.proj = torch.nn.Conv1d(
|
||||
hidden_channels,
|
||||
self.half_channels,
|
||||
1,
|
||||
)
|
||||
else:
|
||||
self.proj = torch.nn.Conv1d(
|
||||
hidden_channels,
|
||||
self.half_channels * 2,
|
||||
1,
|
||||
)
|
||||
self.proj.weight.data.zero_()
|
||||
self.proj.bias.data.zero_()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_mask: torch.Tensor,
|
||||
g: Optional[torch.Tensor] = None,
|
||||
inverse: bool = False,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""Calculate forward propagation.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input tensor (B, in_channels, T).
|
||||
x_lengths (Tensor): Length tensor (B,).
|
||||
g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
|
||||
inverse (bool): Whether to inverse the flow.
|
||||
|
||||
Returns:
|
||||
Tensor: Output tensor (B, in_channels, T).
|
||||
Tensor: Log-determinant tensor for NLL (B,) if not inverse.
|
||||
|
||||
"""
|
||||
xa, xb = x.split(x.size(1) // 2, dim=1)
|
||||
h = self.input_conv(xa) * x_mask
|
||||
h = self.encoder(h, x_mask, g=g)
|
||||
stats = self.proj(h) * x_mask
|
||||
if not self.use_only_mean:
|
||||
m, logs = stats.split(stats.size(1) // 2, dim=1)
|
||||
else:
|
||||
m = stats
|
||||
logs = torch.zeros_like(m)
|
||||
|
||||
if not inverse:
|
||||
xb = m + xb * torch.exp(logs) * x_mask
|
||||
x = torch.cat([xa, xb], 1)
|
||||
logdet = torch.sum(logs, [1, 2])
|
||||
return x, logdet
|
||||
else:
|
||||
xb = (xb - m) * torch.exp(-logs) * x_mask
|
||||
x = torch.cat([xa, xb], 1)
|
||||
return x
|
@ -1,684 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Text encoder module in VITS.
|
||||
|
||||
This code is based on
|
||||
- https://github.com/jaywalnut310/vits
|
||||
- https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/text_encoder.py
|
||||
- https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/transducer_stateless/conformer.py
|
||||
"""
|
||||
|
||||
import copy
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from icefall.utils import is_jit_tracing, make_pad_mask
|
||||
|
||||
|
||||
class TextEncoder(torch.nn.Module):
|
||||
"""Text encoder module in VITS.
|
||||
|
||||
This is a module of text encoder described in `Conditional Variational Autoencoder
|
||||
with Adversarial Learning for End-to-End Text-to-Speech`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocabs: int,
|
||||
d_model: int = 192,
|
||||
num_heads: int = 2,
|
||||
dim_feedforward: int = 768,
|
||||
cnn_module_kernel: int = 5,
|
||||
num_layers: int = 6,
|
||||
dropout: float = 0.1,
|
||||
):
|
||||
"""Initialize TextEncoder module.
|
||||
|
||||
Args:
|
||||
vocabs (int): Vocabulary size.
|
||||
d_model (int): attention dimension
|
||||
num_heads (int): number of attention heads
|
||||
dim_feedforward (int): feedforward dimention
|
||||
cnn_module_kernel (int): convolution kernel size
|
||||
num_layers (int): number of encoder layers
|
||||
dropout (float): dropout rate
|
||||
"""
|
||||
super().__init__()
|
||||
self.d_model = d_model
|
||||
|
||||
# define modules
|
||||
self.emb = torch.nn.Embedding(vocabs, d_model)
|
||||
torch.nn.init.normal_(self.emb.weight, 0.0, d_model**-0.5)
|
||||
|
||||
# We use conformer as text encoder
|
||||
self.encoder = Transformer(
|
||||
d_model=d_model,
|
||||
num_heads=num_heads,
|
||||
dim_feedforward=dim_feedforward,
|
||||
cnn_module_kernel=cnn_module_kernel,
|
||||
num_layers=num_layers,
|
||||
dropout=dropout,
|
||||
)
|
||||
|
||||
self.proj = torch.nn.Conv1d(d_model, d_model * 2, 1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lengths: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Calculate forward propagation.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input index tensor (B, T_text).
|
||||
x_lengths (Tensor): Length tensor (B,).
|
||||
|
||||
Returns:
|
||||
Tensor: Encoded hidden representation (B, attention_dim, T_text).
|
||||
Tensor: Projected mean tensor (B, attention_dim, T_text).
|
||||
Tensor: Projected scale tensor (B, attention_dim, T_text).
|
||||
Tensor: Mask tensor for input tensor (B, 1, T_text).
|
||||
|
||||
"""
|
||||
# (B, T_text, embed_dim)
|
||||
x = self.emb(x) * math.sqrt(self.d_model)
|
||||
|
||||
assert x.size(1) == x_lengths.max().item()
|
||||
|
||||
# (B, T_text)
|
||||
pad_mask = make_pad_mask(x_lengths)
|
||||
|
||||
# encoder assume the channel last (B, T_text, embed_dim)
|
||||
x = self.encoder(x, key_padding_mask=pad_mask)
|
||||
|
||||
# convert the channel first (B, embed_dim, T_text)
|
||||
x = x.transpose(1, 2)
|
||||
non_pad_mask = (~pad_mask).unsqueeze(1)
|
||||
stats = self.proj(x) * non_pad_mask
|
||||
m, logs = stats.split(stats.size(1) // 2, dim=1)
|
||||
|
||||
return x, m, logs, non_pad_mask
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
"""
|
||||
Args:
|
||||
d_model (int): attention dimension
|
||||
num_heads (int): number of attention heads
|
||||
dim_feedforward (int): feedforward dimention
|
||||
cnn_module_kernel (int): convolution kernel size
|
||||
num_layers (int): number of encoder layers
|
||||
dropout (float): dropout rate
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int = 192,
|
||||
num_heads: int = 2,
|
||||
dim_feedforward: int = 768,
|
||||
cnn_module_kernel: int = 5,
|
||||
num_layers: int = 6,
|
||||
dropout: float = 0.1,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.num_layers = num_layers
|
||||
self.d_model = d_model
|
||||
|
||||
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
|
||||
|
||||
encoder_layer = TransformerEncoderLayer(
|
||||
d_model=d_model,
|
||||
num_heads=num_heads,
|
||||
dim_feedforward=dim_feedforward,
|
||||
cnn_module_kernel=cnn_module_kernel,
|
||||
dropout=dropout,
|
||||
)
|
||||
self.encoder = TransformerEncoder(encoder_layer, num_layers)
|
||||
self.after_norm = nn.LayerNorm(d_model)
|
||||
|
||||
def forward(
|
||||
self, x: Tensor, key_padding_mask: Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
The input tensor. Its shape is (batch_size, seq_len, feature_dim).
|
||||
lengths:
|
||||
A tensor of shape (batch_size,) containing the number of frames in
|
||||
`x` before padding.
|
||||
"""
|
||||
x, pos_emb = self.encoder_pos(x)
|
||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||
|
||||
x = self.encoder(x, pos_emb, key_padding_mask=key_padding_mask) # (T, N, C)
|
||||
|
||||
x = self.after_norm(x)
|
||||
|
||||
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||
return x
|
||||
|
||||
|
||||
class TransformerEncoderLayer(nn.Module):
|
||||
"""
|
||||
TransformerEncoderLayer is made up of self-attn and feedforward.
|
||||
|
||||
Args:
|
||||
d_model: the number of expected features in the input.
|
||||
num_heads: the number of heads in the multi-head attention models.
|
||||
dim_feedforward: the dimension of the feed-forward network model.
|
||||
dropout: the dropout value (default=0.1).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int,
|
||||
num_heads: int,
|
||||
dim_feedforward: int,
|
||||
cnn_module_kernel: int,
|
||||
dropout: float = 0.1,
|
||||
) -> None:
|
||||
super(TransformerEncoderLayer, self).__init__()
|
||||
|
||||
self.feed_forward_macaron = nn.Sequential(
|
||||
nn.Linear(d_model, dim_feedforward),
|
||||
Swish(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(dim_feedforward, d_model),
|
||||
)
|
||||
|
||||
self.self_attn = RelPositionMultiheadAttention(
|
||||
d_model, num_heads, dropout=dropout
|
||||
)
|
||||
|
||||
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
|
||||
|
||||
self.feed_forward = nn.Sequential(
|
||||
nn.Linear(d_model, dim_feedforward),
|
||||
Swish(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(dim_feedforward, d_model),
|
||||
)
|
||||
|
||||
self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module
|
||||
self.norm_mha = nn.LayerNorm(d_model) # for the MHA module
|
||||
self.norm_conv = nn.LayerNorm(d_model) # for the CNN module
|
||||
self.norm_final = nn.LayerNorm(d_model) # for the final output of the block
|
||||
self.norm_ff = nn.LayerNorm(d_model) # for the FNN module
|
||||
|
||||
self.ff_scale = 0.5
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
src: Tensor,
|
||||
pos_emb: Tensor,
|
||||
key_padding_mask: Optional[Tensor] = None,
|
||||
) -> Tensor:
|
||||
"""
|
||||
Pass the input through the transformer encoder layer.
|
||||
|
||||
Args:
|
||||
src: the sequence to the encoder layer, of shape (seq_len, batch_size, embed_dim).
|
||||
pos_emb: Positional embedding tensor, of shape (1, seq_len*2-1, pos_dim).
|
||||
key_padding_mask: the mask for the src keys per batch, of shape (batch_size, seq_len)
|
||||
"""
|
||||
# macaron style feed-forward module
|
||||
src = src + self.ff_scale * self.dropout(
|
||||
self.feed_forward_macaron(self.norm_ff_macaron(src))
|
||||
)
|
||||
|
||||
# multi-head self-attention module
|
||||
src_attn = self.self_attn(
|
||||
self.norm_mha(src),
|
||||
pos_emb=pos_emb,
|
||||
key_padding_mask=key_padding_mask,
|
||||
)
|
||||
src = src + self.dropout(src_attn)
|
||||
|
||||
# convolution module
|
||||
src = src + self.dropout(self.conv_module(self.norm_conv(src)))
|
||||
|
||||
# feed-forward module
|
||||
src = src + self.dropout(self.feed_forward(self.norm_ff(src)))
|
||||
|
||||
src = self.norm_final(src)
|
||||
|
||||
return src
|
||||
|
||||
|
||||
class TransformerEncoder(nn.Module):
|
||||
r"""TransformerEncoder is a stack of N encoder layers
|
||||
|
||||
Args:
|
||||
encoder_layer: an instance of the TransformerEncoderLayer class.
|
||||
num_layers: the number of sub-encoder-layers in the encoder.
|
||||
"""
|
||||
|
||||
def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[copy.deepcopy(encoder_layer) for i in range(num_layers)]
|
||||
)
|
||||
self.num_layers = num_layers
|
||||
|
||||
def forward(
|
||||
self,
|
||||
src: Tensor,
|
||||
pos_emb: Tensor,
|
||||
key_padding_mask: Optional[Tensor] = None,
|
||||
) -> Tensor:
|
||||
r"""Pass the input through the encoder layers in turn.
|
||||
|
||||
Args:
|
||||
src: the sequence to the encoder layer, of shape (seq_len, batch_size, embed_dim).
|
||||
pos_emb: Positional embedding tensor, of shape (1, seq_len*2-1, pos_dim).
|
||||
key_padding_mask: the mask for the src keys per batch, of shape (batch_size, seq_len)
|
||||
"""
|
||||
output = src
|
||||
|
||||
for layer_index, mod in enumerate(self.layers):
|
||||
output = mod(
|
||||
output,
|
||||
pos_emb,
|
||||
key_padding_mask=key_padding_mask,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class RelPositionalEncoding(torch.nn.Module):
|
||||
"""Relative positional encoding module.
|
||||
|
||||
See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
|
||||
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py
|
||||
|
||||
Args:
|
||||
d_model: Embedding dimension.
|
||||
dropout_rate: Dropout rate.
|
||||
max_len: Maximum input length.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
|
||||
"""Construct an PositionalEncoding object."""
|
||||
super(RelPositionalEncoding, self).__init__()
|
||||
|
||||
self.d_model = d_model
|
||||
self.xscale = math.sqrt(self.d_model)
|
||||
self.dropout = torch.nn.Dropout(p=dropout_rate)
|
||||
self.pe = None
|
||||
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
|
||||
|
||||
def extend_pe(self, x: Tensor) -> None:
|
||||
"""Reset the positional encodings."""
|
||||
x_size = x.size(1)
|
||||
if self.pe is not None:
|
||||
# self.pe contains both positive and negative parts
|
||||
# the length of self.pe is 2 * input_len - 1
|
||||
if self.pe.size(1) >= x_size * 2 - 1:
|
||||
# Note: TorchScript doesn't implement operator== for torch.Device
|
||||
if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
|
||||
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
||||
return
|
||||
# Suppose `i` means to the position of query vector and `j` means the
|
||||
# position of key vector. We use position relative positions when keys
|
||||
# are to the left (i>j) and negative relative positions otherwise (i<j).
|
||||
pe_positive = torch.zeros(x_size, self.d_model)
|
||||
pe_negative = torch.zeros(x_size, self.d_model)
|
||||
position = torch.arange(0, x_size, dtype=torch.float32).unsqueeze(1)
|
||||
div_term = torch.exp(
|
||||
torch.arange(0, self.d_model, 2, dtype=torch.float32)
|
||||
* -(math.log(10000.0) / self.d_model)
|
||||
)
|
||||
pe_positive[:, 0::2] = torch.sin(position * div_term)
|
||||
pe_positive[:, 1::2] = torch.cos(position * div_term)
|
||||
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
|
||||
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
|
||||
|
||||
# Reserve the order of positive indices and concat both positive and
|
||||
# negative indices. This is used to support the shifting trick
|
||||
# as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
|
||||
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
|
||||
pe_negative = pe_negative[1:].unsqueeze(0)
|
||||
pe = torch.cat([pe_positive, pe_negative], dim=1)
|
||||
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
|
||||
"""Add positional encoding.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor (batch, time, `*`).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Encoded tensor (batch, time, `*`).
|
||||
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
|
||||
"""
|
||||
self.extend_pe(x)
|
||||
x = x * self.xscale
|
||||
pos_emb = self.pe[
|
||||
:,
|
||||
self.pe.size(1) // 2
|
||||
- x.size(1)
|
||||
+ 1 : self.pe.size(1) // 2 # noqa E203
|
||||
+ x.size(1),
|
||||
]
|
||||
return self.dropout(x), self.dropout(pos_emb)
|
||||
|
||||
|
||||
class RelPositionMultiheadAttention(nn.Module):
|
||||
r"""Multi-Head Attention layer with relative position encoding
|
||||
|
||||
See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
|
||||
|
||||
Args:
|
||||
embed_dim: total dimension of the model.
|
||||
num_heads: parallel attention heads.
|
||||
dropout: a Dropout layer on attn_output_weights. Default: 0.0.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
num_heads: int,
|
||||
dropout: float = 0.0,
|
||||
) -> None:
|
||||
super(RelPositionMultiheadAttention, self).__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.num_heads = num_heads
|
||||
self.dropout = dropout
|
||||
self.head_dim = embed_dim // num_heads
|
||||
assert (
|
||||
self.head_dim * num_heads == self.embed_dim
|
||||
), "embed_dim must be divisible by num_heads"
|
||||
|
||||
self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
|
||||
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
|
||||
|
||||
# linear transformation for positional encoding.
|
||||
self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False)
|
||||
# these two learnable bias are used in matrix c and matrix d
|
||||
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
|
||||
self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
|
||||
self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
|
||||
|
||||
self._reset_parameters()
|
||||
|
||||
def _reset_parameters(self) -> None:
|
||||
nn.init.xavier_uniform_(self.in_proj.weight)
|
||||
nn.init.constant_(self.in_proj.bias, 0.0)
|
||||
nn.init.constant_(self.out_proj.bias, 0.0)
|
||||
|
||||
nn.init.xavier_uniform_(self.pos_bias_u)
|
||||
nn.init.xavier_uniform_(self.pos_bias_v)
|
||||
|
||||
def rel_shift(self, x: Tensor) -> Tensor:
|
||||
"""Compute relative positional encoding.
|
||||
|
||||
Args:
|
||||
x: Input tensor (batch, head, seq_len, 2*seq_len-1).
|
||||
|
||||
Returns:
|
||||
Tensor: tensor of shape (batch, head, seq_len, seq_len)
|
||||
"""
|
||||
(batch_size, num_heads, seq_len, n) = x.shape
|
||||
|
||||
if not is_jit_tracing():
|
||||
assert n == 2 * seq_len - 1, f"{n} == 2 * {seq_len} - 1"
|
||||
|
||||
if is_jit_tracing():
|
||||
rows = torch.arange(start=seq_len - 1, end=-1, step=-1)
|
||||
cols = torch.arange(seq_len)
|
||||
rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
|
||||
indexes = rows + cols
|
||||
|
||||
x = x.reshape(-1, n)
|
||||
x = torch.gather(x, dim=1, index=indexes)
|
||||
x = x.reshape(batch_size, num_heads, seq_len, seq_len)
|
||||
return x
|
||||
else:
|
||||
# Note: TorchScript requires explicit arg for stride()
|
||||
batch_stride = x.stride(0)
|
||||
head_stride = x.stride(1)
|
||||
time_stride = x.stride(2)
|
||||
n_stride = x.stride(3)
|
||||
return x.as_strided(
|
||||
(batch_size, num_heads, seq_len, seq_len),
|
||||
(batch_stride, head_stride, time_stride - n_stride, n_stride),
|
||||
storage_offset=n_stride * (seq_len - 1),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
pos_emb: Tensor,
|
||||
key_padding_mask: Optional[Tensor] = None,
|
||||
) -> Tensor:
|
||||
"""
|
||||
Args:
|
||||
x: Input tensor of shape (seq_len, batch_size, embed_dim)
|
||||
pos_emb: Positional embedding tensor, (1, 2*seq_len-1, pos_dim)
|
||||
key_padding_mask: if provided, specified padding elements in the key will
|
||||
be ignored by the attention. This is an binary mask. When the value is True,
|
||||
the corresponding value on the attention layer will be filled with -inf.
|
||||
Its shape is (batch_size, seq_len).
|
||||
|
||||
Outputs:
|
||||
A tensor of shape (seq_len, batch_size, embed_dim).
|
||||
"""
|
||||
seq_len, batch_size, _ = x.shape
|
||||
scaling = float(self.head_dim) ** -0.5
|
||||
|
||||
q, k, v = self.in_proj(x).chunk(3, dim=-1)
|
||||
|
||||
q = q.contiguous().view(seq_len, batch_size, self.num_heads, self.head_dim)
|
||||
k = k.contiguous().view(seq_len, batch_size, self.num_heads, self.head_dim)
|
||||
v = (
|
||||
v.contiguous()
|
||||
.view(seq_len, batch_size * self.num_heads, self.head_dim)
|
||||
.transpose(0, 1)
|
||||
)
|
||||
|
||||
q = q.transpose(0, 1) # (batch_size, seq_len, num_head, head_dim)
|
||||
|
||||
p = self.linear_pos(pos_emb).view(
|
||||
pos_emb.size(0), -1, self.num_heads, self.head_dim
|
||||
)
|
||||
# (1, 2*seq_len, num_head, head_dim) -> (1, num_head, head_dim, 2*seq_len-1)
|
||||
p = p.permute(0, 2, 3, 1)
|
||||
|
||||
# (batch_size, num_head, seq_len, head_dim)
|
||||
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
|
||||
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
|
||||
|
||||
# compute attention score
|
||||
# first compute matrix a and matrix c
|
||||
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
|
||||
k = k.permute(1, 2, 3, 0) # (batch_size, num_head, head_dim, seq_len)
|
||||
matrix_ac = torch.matmul(
|
||||
q_with_bias_u, k
|
||||
) # (batch_size, num_head, seq_len, seq_len)
|
||||
|
||||
# compute matrix b and matrix d
|
||||
matrix_bd = torch.matmul(
|
||||
q_with_bias_v, p
|
||||
) # (batch_size, num_head, seq_len, 2*seq_len-1)
|
||||
matrix_bd = self.rel_shift(
|
||||
matrix_bd
|
||||
) # (batch_size, num_head, seq_len, seq_len)
|
||||
|
||||
# (batch_size, num_head, seq_len, seq_len)
|
||||
attn_output_weights = (matrix_ac + matrix_bd) * scaling
|
||||
attn_output_weights = attn_output_weights.view(
|
||||
batch_size * self.num_heads, seq_len, seq_len
|
||||
)
|
||||
|
||||
if key_padding_mask is not None:
|
||||
assert key_padding_mask.shape == (batch_size, seq_len)
|
||||
attn_output_weights = attn_output_weights.view(
|
||||
batch_size, self.num_heads, seq_len, seq_len
|
||||
)
|
||||
attn_output_weights = attn_output_weights.masked_fill(
|
||||
key_padding_mask.unsqueeze(1).unsqueeze(2),
|
||||
float("-inf"),
|
||||
)
|
||||
attn_output_weights = attn_output_weights.view(
|
||||
batch_size * self.num_heads, seq_len, seq_len
|
||||
)
|
||||
|
||||
attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
|
||||
attn_output_weights = nn.functional.dropout(
|
||||
attn_output_weights, p=self.dropout, training=self.training
|
||||
)
|
||||
|
||||
# (batch_size * num_head, seq_len, head_dim)
|
||||
attn_output = torch.bmm(attn_output_weights, v)
|
||||
assert attn_output.shape == (
|
||||
batch_size * self.num_heads,
|
||||
seq_len,
|
||||
self.head_dim,
|
||||
)
|
||||
|
||||
attn_output = (
|
||||
attn_output.transpose(0, 1)
|
||||
.contiguous()
|
||||
.view(seq_len, batch_size, self.embed_dim)
|
||||
)
|
||||
# (seq_len, batch_size, embed_dim)
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output
|
||||
|
||||
|
||||
class ConvolutionModule(nn.Module):
|
||||
"""ConvolutionModule in Conformer model.
|
||||
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py
|
||||
|
||||
Args:
|
||||
channels (int): The number of channels of conv layers.
|
||||
kernel_size (int): Kernerl size of conv layers.
|
||||
bias (bool): Whether to use bias in conv layers (default=True).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
kernel_size: int,
|
||||
bias: bool = True,
|
||||
) -> None:
|
||||
"""Construct an ConvolutionModule object."""
|
||||
super(ConvolutionModule, self).__init__()
|
||||
# kernerl_size should be a odd number for 'SAME' padding
|
||||
assert (kernel_size - 1) % 2 == 0
|
||||
|
||||
self.pointwise_conv1 = nn.Conv1d(
|
||||
channels,
|
||||
2 * channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
padding = (kernel_size - 1) // 2
|
||||
self.depthwise_conv = nn.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=padding,
|
||||
groups=channels,
|
||||
bias=bias,
|
||||
)
|
||||
self.norm = nn.LayerNorm(channels)
|
||||
self.pointwise_conv2 = nn.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias=bias,
|
||||
)
|
||||
self.activation = Swish()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""Compute convolution module.
|
||||
|
||||
Args:
|
||||
x: Input tensor (#time, batch, channels).
|
||||
src_key_padding_mask: the mask for the src keys per batch (optional).
|
||||
|
||||
Returns:
|
||||
Tensor: Output tensor (#time, batch, channels).
|
||||
|
||||
"""
|
||||
# exchange the temporal dimension and the feature dimension
|
||||
x = x.permute(1, 2, 0) # (#batch, channels, time).
|
||||
|
||||
# GLU mechanism
|
||||
x = self.pointwise_conv1(x) # (batch, 2*channels, time)
|
||||
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
|
||||
|
||||
# 1D Depthwise Conv
|
||||
if src_key_padding_mask is not None:
|
||||
x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
|
||||
x = self.depthwise_conv(x)
|
||||
# x is (batch, channels, time)
|
||||
x = x.permute(0, 2, 1)
|
||||
x = self.norm(x)
|
||||
x = x.permute(0, 2, 1)
|
||||
|
||||
x = self.activation(x)
|
||||
|
||||
x = self.pointwise_conv2(x) # (batch, channel, time)
|
||||
|
||||
return x.permute(2, 0, 1)
|
||||
|
||||
|
||||
class Swish(nn.Module):
|
||||
"""Construct an Swish object."""
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
"""Return Swich activation function."""
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
def _test_text_encoder():
|
||||
vocabs = 500
|
||||
d_model = 192
|
||||
batch_size = 5
|
||||
seq_len = 100
|
||||
|
||||
m = TextEncoder(vocabs=vocabs, d_model=d_model)
|
||||
x, m, logs, mask = m(
|
||||
x=torch.randint(low=0, high=vocabs, size=(batch_size, seq_len)),
|
||||
x_lengths=torch.full((batch_size,), seq_len),
|
||||
)
|
||||
print(x.shape, m.shape, logs.shape, mask.shape)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_test_text_encoder()
|
@ -1,108 +0,0 @@
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
|
||||
#
|
||||
# See ../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Dict, List
|
||||
|
||||
import g2p_en
|
||||
import tacotron_cleaner.cleaners
|
||||
from utils import intersperse
|
||||
|
||||
|
||||
class Tokenizer(object):
|
||||
def __init__(self, tokens: str):
|
||||
"""
|
||||
Args:
|
||||
tokens: the file that maps tokens to ids
|
||||
"""
|
||||
# Parse token file
|
||||
self.token2id: Dict[str, int] = {}
|
||||
with open(tokens, "r", encoding="utf-8") as f:
|
||||
for line in f.readlines():
|
||||
info = line.rstrip().split()
|
||||
if len(info) == 1:
|
||||
# case of space
|
||||
token = " "
|
||||
id = int(info[0])
|
||||
else:
|
||||
token, id = info[0], int(info[1])
|
||||
self.token2id[token] = id
|
||||
|
||||
self.blank_id = self.token2id["<blk>"]
|
||||
self.oov_id = self.token2id["<unk>"]
|
||||
self.vocab_size = len(self.token2id)
|
||||
|
||||
self.g2p = g2p_en.G2p()
|
||||
|
||||
def texts_to_token_ids(self, texts: List[str], intersperse_blank: bool = True):
|
||||
"""
|
||||
Args:
|
||||
texts:
|
||||
A list of transcripts.
|
||||
intersperse_blank:
|
||||
Whether to intersperse blanks in the token sequence.
|
||||
|
||||
Returns:
|
||||
Return a list of token id list [utterance][token_id]
|
||||
"""
|
||||
token_ids_list = []
|
||||
|
||||
for text in texts:
|
||||
# Text normalization
|
||||
text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
|
||||
# Convert to phonemes
|
||||
tokens = self.g2p(text)
|
||||
token_ids = []
|
||||
for t in tokens:
|
||||
if t in self.token2id:
|
||||
token_ids.append(self.token2id[t])
|
||||
else:
|
||||
token_ids.append(self.oov_id)
|
||||
|
||||
if intersperse_blank:
|
||||
token_ids = intersperse(token_ids, self.blank_id)
|
||||
|
||||
token_ids_list.append(token_ids)
|
||||
|
||||
return token_ids_list
|
||||
|
||||
def tokens_to_token_ids(
|
||||
self, tokens_list: List[str], intersperse_blank: bool = True
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
tokens_list:
|
||||
A list of token list, each corresponding to one utterance.
|
||||
intersperse_blank:
|
||||
Whether to intersperse blanks in the token sequence.
|
||||
|
||||
Returns:
|
||||
Return a list of token id list [utterance][token_id]
|
||||
"""
|
||||
token_ids_list = []
|
||||
|
||||
for tokens in tokens_list:
|
||||
token_ids = []
|
||||
for t in tokens:
|
||||
if t in self.token2id:
|
||||
token_ids.append(self.token2id[t])
|
||||
else:
|
||||
token_ids.append(self.oov_id)
|
||||
|
||||
if intersperse_blank:
|
||||
token_ids = intersperse(token_ids, self.blank_id)
|
||||
token_ids_list.append(token_ids)
|
||||
|
||||
return token_ids_list
|
Loading…
x
Reference in New Issue
Block a user