#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Text encoder module in VITS.
This code is based on
- https://github.com/jaywalnut310/vits
- https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/text_encoder.py
- https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/transducer_stateless/conformer.py
"""
import copy
import math
from typing import Optional, Tuple
import torch
from torch import Tensor, nn
from icefall.utils import is_jit_tracing, make_pad_mask
class TextEncoder(torch.nn.Module):
"""Text encoder module in VITS.
    This is the text encoder described in `Conditional Variational Autoencoder
    with Adversarial Learning for End-to-End Text-to-Speech`.
"""
def __init__(
self,
vocabs: int,
d_model: int = 192,
num_heads: int = 2,
dim_feedforward: int = 768,
cnn_module_kernel: int = 5,
num_layers: int = 6,
dropout: float = 0.1,
):
"""Initialize TextEncoder module.
Args:
vocabs (int): Vocabulary size.
d_model (int): attention dimension
num_heads (int): number of attention heads
            dim_feedforward (int): feedforward dimension
cnn_module_kernel (int): convolution kernel size
num_layers (int): number of encoder layers
dropout (float): dropout rate
"""
super().__init__()
self.d_model = d_model
# define modules
self.emb = torch.nn.Embedding(vocabs, d_model)
torch.nn.init.normal_(self.emb.weight, 0.0, d_model**-0.5)
        # We use a Conformer as the text encoder
self.encoder = Transformer(
d_model=d_model,
num_heads=num_heads,
dim_feedforward=dim_feedforward,
cnn_module_kernel=cnn_module_kernel,
num_layers=num_layers,
dropout=dropout,
)
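        # the 1x1 conv doubles the channels so the output can later be split
        # into mean and log-scale halves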
self.proj = torch.nn.Conv1d(d_model, d_model * 2, 1)
def forward(
self,
x: torch.Tensor,
x_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Calculate forward propagation.
Args:
x (Tensor): Input index tensor (B, T_text).
x_lengths (Tensor): Length tensor (B,).
Returns:
Tensor: Encoded hidden representation (B, attention_dim, T_text).
Tensor: Projected mean tensor (B, attention_dim, T_text).
Tensor: Projected scale tensor (B, attention_dim, T_text).
Tensor: Mask tensor for input tensor (B, 1, T_text).
"""
# (B, T_text, embed_dim)
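        # scale embeddings by sqrt(d_model); combined with the N(0, d_model**-0.5)
        # init above, this keeps the embedding magnitude roughly unit scale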
x = self.emb(x) * math.sqrt(self.d_model)
assert x.size(1) == x_lengths.max().item()
# (B, T_text)
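        # make_pad_mask returns True at padded positions (index >= length)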
pad_mask = make_pad_mask(x_lengths)
        # the encoder assumes channels-last input (B, T_text, embed_dim)
x = self.encoder(x, key_padding_mask=pad_mask)
        # convert to channels-first layout (B, embed_dim, T_text)
x = x.transpose(1, 2)
non_pad_mask = (~pad_mask).unsqueeze(1)
stats = self.proj(x) * non_pad_mask
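        # split the projected stats into the mean (m) and log-scale (logs) of
        # the text prior; downstream VITS code draws samples roughly as
        # z = m + exp(logs) * noise during inference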
m, logs = stats.split(stats.size(1) // 2, dim=1)
return x, m, logs, non_pad_mask
class Transformer(nn.Module):
"""
Args:
d_model (int): attention dimension
num_heads (int): number of attention heads
        dim_feedforward (int): feedforward dimension
cnn_module_kernel (int): convolution kernel size
num_layers (int): number of encoder layers
dropout (float): dropout rate
"""
def __init__(
self,
d_model: int = 192,
num_heads: int = 2,
dim_feedforward: int = 768,
cnn_module_kernel: int = 5,
num_layers: int = 6,
dropout: float = 0.1,
) -> None:
super().__init__()
self.num_layers = num_layers
self.d_model = d_model
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
encoder_layer = TransformerEncoderLayer(
d_model=d_model,
num_heads=num_heads,
dim_feedforward=dim_feedforward,
cnn_module_kernel=cnn_module_kernel,
dropout=dropout,
)
self.encoder = TransformerEncoder(encoder_layer, num_layers)
self.after_norm = nn.LayerNorm(d_model)
    def forward(
        self, x: Tensor, key_padding_mask: Tensor
    ) -> Tensor:
        """
        Args:
          x:
            The input tensor. Its shape is (batch_size, seq_len, feature_dim).
          key_padding_mask:
            A boolean tensor of shape (batch_size, seq_len), with True marking
            the padded positions that should be ignored by the attention.
        Returns:
          A tensor of shape (batch_size, seq_len, feature_dim).
        """
x, pos_emb = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
x = self.encoder(x, pos_emb, key_padding_mask=key_padding_mask) # (T, N, C)
x = self.after_norm(x)
        x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
return x
class TransformerEncoderLayer(nn.Module):
"""
TransformerEncoderLayer is made up of self-attn and feedforward.
Args:
d_model: the number of expected features in the input.
num_heads: the number of heads in the multi-head attention models.
        dim_feedforward: the dimension of the feed-forward network model.
        cnn_module_kernel: the kernel size of the convolution module.
        dropout: the dropout value (default=0.1).
"""
def __init__(
self,
d_model: int,
num_heads: int,
dim_feedforward: int,
cnn_module_kernel: int,
dropout: float = 0.1,
) -> None:
super(TransformerEncoderLayer, self).__init__()
self.feed_forward_macaron = nn.Sequential(
nn.Linear(d_model, dim_feedforward),
Swish(),
nn.Dropout(dropout),
nn.Linear(dim_feedforward, d_model),
)
self.self_attn = RelPositionMultiheadAttention(
d_model, num_heads, dropout=dropout
)
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
self.feed_forward = nn.Sequential(
nn.Linear(d_model, dim_feedforward),
Swish(),
nn.Dropout(dropout),
nn.Linear(dim_feedforward, d_model),
)
        self.norm_ff_macaron = nn.LayerNorm(d_model)  # for the macaron-style FFN module
        self.norm_mha = nn.LayerNorm(d_model)  # for the MHA module
        self.norm_conv = nn.LayerNorm(d_model)  # for the CNN module
        self.norm_final = nn.LayerNorm(d_model)  # for the final output of the block
        self.norm_ff = nn.LayerNorm(d_model)  # for the FFN module
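        # the macaron-style feed-forward branch uses a half-step residual
        # weight (0.5), following the Conformer architecture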
self.ff_scale = 0.5
self.dropout = nn.Dropout(dropout)
def forward(
self,
src: Tensor,
pos_emb: Tensor,
key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
"""
Pass the input through the transformer encoder layer.
Args:
src: the sequence to the encoder layer, of shape (seq_len, batch_size, embed_dim).
pos_emb: Positional embedding tensor, of shape (1, seq_len*2-1, pos_dim).
key_padding_mask: the mask for the src keys per batch, of shape (batch_size, seq_len)
"""
# macaron style feed-forward module
src = src + self.ff_scale * self.dropout(
self.feed_forward_macaron(self.norm_ff_macaron(src))
)
# multi-head self-attention module
src_attn = self.self_attn(
self.norm_mha(src),
pos_emb=pos_emb,
key_padding_mask=key_padding_mask,
)
src = src + self.dropout(src_attn)
# convolution module
src = src + self.dropout(self.conv_module(self.norm_conv(src)))
# feed-forward module
src = src + self.dropout(self.feed_forward(self.norm_ff(src)))
src = self.norm_final(src)
return src
class TransformerEncoder(nn.Module):
r"""TransformerEncoder is a stack of N encoder layers
Args:
encoder_layer: an instance of the TransformerEncoderLayer class.
num_layers: the number of sub-encoder-layers in the encoder.
"""
def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None:
super().__init__()
self.layers = nn.ModuleList(
            [copy.deepcopy(encoder_layer) for _ in range(num_layers)]
)
self.num_layers = num_layers
def forward(
self,
src: Tensor,
pos_emb: Tensor,
key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
r"""Pass the input through the encoder layers in turn.
Args:
src: the sequence to the encoder layer, of shape (seq_len, batch_size, embed_dim).
pos_emb: Positional embedding tensor, of shape (1, seq_len*2-1, pos_dim).
key_padding_mask: the mask for the src keys per batch, of shape (batch_size, seq_len)
"""
output = src
        for mod in self.layers:
output = mod(
output,
pos_emb,
key_padding_mask=key_padding_mask,
)
return output
class RelPositionalEncoding(torch.nn.Module):
"""Relative positional encoding module.
    See: Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py
Args:
d_model: Embedding dimension.
dropout_rate: Dropout rate.
max_len: Maximum input length.
"""
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
"""Construct an PositionalEncoding object."""
super(RelPositionalEncoding, self).__init__()
self.d_model = d_model
self.xscale = math.sqrt(self.d_model)
self.dropout = torch.nn.Dropout(p=dropout_rate)
self.pe = None
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
def extend_pe(self, x: Tensor) -> None:
"""Reset the positional encodings."""
x_size = x.size(1)
if self.pe is not None:
# self.pe contains both positive and negative parts
# the length of self.pe is 2 * input_len - 1
if self.pe.size(1) >= x_size * 2 - 1:
# Note: TorchScript doesn't implement operator== for torch.Device
if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
        # Suppose `i` is the position of the query vector and `j` is the
        # position of the key vector. We use positive relative positions when
        # the key is to the left (i > j) and negative relative positions
        # otherwise (i < j).
pe_positive = torch.zeros(x_size, self.d_model)
pe_negative = torch.zeros(x_size, self.d_model)
position = torch.arange(0, x_size, dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.d_model)
)
pe_positive[:, 0::2] = torch.sin(position * div_term)
pe_positive[:, 1::2] = torch.cos(position * div_term)
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
        # Reverse the order of positive indices and concat both positive and
# negative indices. This is used to support the shifting trick
# as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
pe_negative = pe_negative[1:].unsqueeze(0)
pe = torch.cat([pe_positive, pe_negative], dim=1)
self.pe = pe.to(device=x.device, dtype=x.dtype)
def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
"""Add positional encoding.
Args:
x (torch.Tensor): Input tensor (batch, time, `*`).
Returns:
            torch.Tensor: Scaled input tensor (batch, time, `*`).
            torch.Tensor: Positional embedding tensor (1, 2*time-1, `*`).
"""
self.extend_pe(x)
x = x * self.xscale
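        # self.pe covers relative positions +(L-1) .. -(L-1) for the largest
        # length L seen so far, with relative position 0 at index
        # self.pe.size(1) // 2; slice out the 2*time-1 entries centred there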
pos_emb = self.pe[
:,
self.pe.size(1) // 2
- x.size(1)
+ 1 : self.pe.size(1) // 2 # noqa E203
+ x.size(1),
]
return self.dropout(x), self.dropout(pos_emb)
class RelPositionMultiheadAttention(nn.Module):
r"""Multi-Head Attention layer with relative position encoding
See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
Args:
embed_dim: total dimension of the model.
num_heads: parallel attention heads.
dropout: a Dropout layer on attn_output_weights. Default: 0.0.
"""
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
) -> None:
super(RelPositionMultiheadAttention, self).__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
assert (
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
# linear transformation for positional encoding.
self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False)
        # these two learnable biases are used in matrix c and matrix d
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
self._reset_parameters()
def _reset_parameters(self) -> None:
nn.init.xavier_uniform_(self.in_proj.weight)
nn.init.constant_(self.in_proj.bias, 0.0)
nn.init.constant_(self.out_proj.bias, 0.0)
nn.init.xavier_uniform_(self.pos_bias_u)
nn.init.xavier_uniform_(self.pos_bias_v)
def rel_shift(self, x: Tensor) -> Tensor:
"""Compute relative positional encoding.
Args:
x: Input tensor (batch, head, seq_len, 2*seq_len-1).
Returns:
Tensor: tensor of shape (batch, head, seq_len, seq_len)
"""
(batch_size, num_heads, seq_len, n) = x.shape
if not is_jit_tracing():
assert n == 2 * seq_len - 1, f"{n} == 2 * {seq_len} - 1"
if is_jit_tracing():
rows = torch.arange(start=seq_len - 1, end=-1, step=-1)
cols = torch.arange(seq_len)
rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
indexes = rows + cols
x = x.reshape(-1, n)
x = torch.gather(x, dim=1, index=indexes)
x = x.reshape(batch_size, num_heads, seq_len, seq_len)
return x
else:
# Note: TorchScript requires explicit arg for stride()
batch_stride = x.stride(0)
head_stride = x.stride(1)
time_stride = x.stride(2)
n_stride = x.stride(3)
return x.as_strided(
(batch_size, num_heads, seq_len, seq_len),
(batch_stride, head_stride, time_stride - n_stride, n_stride),
storage_offset=n_stride * (seq_len - 1),
)
def forward(
self,
x: Tensor,
pos_emb: Tensor,
key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
"""
Args:
x: Input tensor of shape (seq_len, batch_size, embed_dim)
pos_emb: Positional embedding tensor, (1, 2*seq_len-1, pos_dim)
key_padding_mask: if provided, specified padding elements in the key will
                be ignored by the attention. This is a binary mask. When the value is True,
the corresponding value on the attention layer will be filled with -inf.
Its shape is (batch_size, seq_len).
Outputs:
A tensor of shape (seq_len, batch_size, embed_dim).
"""
seq_len, batch_size, _ = x.shape
scaling = float(self.head_dim) ** -0.5
q, k, v = self.in_proj(x).chunk(3, dim=-1)
q = q.contiguous().view(seq_len, batch_size, self.num_heads, self.head_dim)
k = k.contiguous().view(seq_len, batch_size, self.num_heads, self.head_dim)
v = (
v.contiguous()
.view(seq_len, batch_size * self.num_heads, self.head_dim)
.transpose(0, 1)
)
q = q.transpose(0, 1) # (batch_size, seq_len, num_head, head_dim)
p = self.linear_pos(pos_emb).view(
pos_emb.size(0), -1, self.num_heads, self.head_dim
)
        # (1, 2*seq_len-1, num_head, head_dim) -> (1, num_head, head_dim, 2*seq_len-1)
p = p.permute(0, 2, 3, 1)
# (batch_size, num_head, seq_len, head_dim)
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
# compute attention score
# first compute matrix a and matrix c
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
k = k.permute(1, 2, 3, 0) # (batch_size, num_head, head_dim, seq_len)
matrix_ac = torch.matmul(
q_with_bias_u, k
) # (batch_size, num_head, seq_len, seq_len)
# compute matrix b and matrix d
matrix_bd = torch.matmul(
q_with_bias_v, p
) # (batch_size, num_head, seq_len, 2*seq_len-1)
matrix_bd = self.rel_shift(
matrix_bd
) # (batch_size, num_head, seq_len, seq_len)
# (batch_size, num_head, seq_len, seq_len)
attn_output_weights = (matrix_ac + matrix_bd) * scaling
attn_output_weights = attn_output_weights.view(
batch_size * self.num_heads, seq_len, seq_len
)
if key_padding_mask is not None:
assert key_padding_mask.shape == (batch_size, seq_len)
attn_output_weights = attn_output_weights.view(
batch_size, self.num_heads, seq_len, seq_len
)
attn_output_weights = attn_output_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2),
float("-inf"),
)
attn_output_weights = attn_output_weights.view(
batch_size * self.num_heads, seq_len, seq_len
)
attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
attn_output_weights = nn.functional.dropout(
attn_output_weights, p=self.dropout, training=self.training
)
# (batch_size * num_head, seq_len, head_dim)
attn_output = torch.bmm(attn_output_weights, v)
assert attn_output.shape == (
batch_size * self.num_heads,
seq_len,
self.head_dim,
)
attn_output = (
attn_output.transpose(0, 1)
.contiguous()
.view(seq_len, batch_size, self.embed_dim)
)
# (seq_len, batch_size, embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output
class ConvolutionModule(nn.Module):
"""ConvolutionModule in Conformer model.
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py
Args:
channels (int): The number of channels of conv layers.
        kernel_size (int): Kernel size of conv layers.
bias (bool): Whether to use bias in conv layers (default=True).
"""
def __init__(
self,
channels: int,
kernel_size: int,
bias: bool = True,
) -> None:
"""Construct an ConvolutionModule object."""
super(ConvolutionModule, self).__init__()
        # kernel_size should be an odd number for 'SAME' padding
assert (kernel_size - 1) % 2 == 0
self.pointwise_conv1 = nn.Conv1d(
channels,
2 * channels,
kernel_size=1,
stride=1,
padding=0,
bias=bias,
)
padding = (kernel_size - 1) // 2
self.depthwise_conv = nn.Conv1d(
channels,
channels,
kernel_size,
stride=1,
padding=padding,
groups=channels,
bias=bias,
)
self.norm = nn.LayerNorm(channels)
self.pointwise_conv2 = nn.Conv1d(
channels,
channels,
kernel_size=1,
stride=1,
padding=0,
bias=bias,
)
self.activation = Swish()
def forward(
self,
x: Tensor,
src_key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
"""Compute convolution module.
Args:
x: Input tensor (#time, batch, channels).
src_key_padding_mask: the mask for the src keys per batch (optional).
Returns:
Tensor: Output tensor (#time, batch, channels).
"""
# exchange the temporal dimension and the feature dimension
x = x.permute(1, 2, 0) # (#batch, channels, time).
# GLU mechanism
x = self.pointwise_conv1(x) # (batch, 2*channels, time)
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
# 1D Depthwise Conv
if src_key_padding_mask is not None:
x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
x = self.depthwise_conv(x)
# x is (batch, channels, time)
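        # nn.LayerNorm normalizes over the last dimension, so move channels to
        # the end before applying it and move them back afterwards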
x = x.permute(0, 2, 1)
x = self.norm(x)
x = x.permute(0, 2, 1)
x = self.activation(x)
x = self.pointwise_conv2(x) # (batch, channel, time)
return x.permute(2, 0, 1)
class Swish(nn.Module):
    """Swish (SiLU) activation module: x * sigmoid(x)."""

    def forward(self, x: Tensor) -> Tensor:
        """Return Swish activation applied element-wise."""
        return x * torch.sigmoid(x)
def _test_text_encoder():
vocabs = 500
d_model = 192
batch_size = 5
seq_len = 100
    model = TextEncoder(vocabs=vocabs, d_model=d_model)
    x, m, logs, mask = model(
x=torch.randint(low=0, high=vocabs, size=(batch_size, seq_len)),
x_lengths=torch.full((batch_size,), seq_len),
)
print(x.shape, m.shape, logs.shape, mask.shape)
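

# A minimal added sanity check (illustrative, not part of the original recipe):
# rel_shift should agree with naively indexing the relative-position axis,
# i.e. out[b, h, i, j] = x[b, h, i, seq_len - 1 - i + j].
def _test_rel_shift():
    batch_size, num_heads, seq_len = 2, 4, 7
    attn = RelPositionMultiheadAttention(embed_dim=16, num_heads=num_heads)
    x = torch.randn(batch_size, num_heads, seq_len, 2 * seq_len - 1)
    shifted = attn.rel_shift(x)
    expected = torch.empty(batch_size, num_heads, seq_len, seq_len)
    for i in range(seq_len):
        for j in range(seq_len):
            expected[:, :, i, j] = x[:, :, i, seq_len - 1 - i + j]
    assert torch.allclose(shifted, expected)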
if __name__ == "__main__":
_test_text_encoder()
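    _test_rel_shift()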