use Conformer as text encoder

yaozengwei 2023-11-05 18:25:47 +08:00
parent b719581e2f
commit 8d09f8e6bf
4 changed files with 143 additions and 4 deletions

View File

@@ -44,6 +44,7 @@ class VITSGenerator(torch.nn.Module):
segment_size: int = 32,
text_encoder_attention_heads: int = 2,
text_encoder_ffn_expand: int = 4,
text_encoder_cnn_module_kernel: int = 5,
text_encoder_blocks: int = 6,
text_encoder_dropout_rate: float = 0.1,
decoder_kernel_size: int = 7,
@@ -89,6 +90,7 @@ class VITSGenerator(torch.nn.Module):
of text encoder.
text_encoder_ffn_expand (int): Expansion ratio of FFN in conformer block
of text encoder.
text_encoder_cnn_module_kernel (int): Convolution kernel size in text encoder.
text_encoder_blocks (int): Number of conformer blocks in text encoder.
text_encoder_dropout_rate (float): Dropout rate in conformer block of
text encoder.
@@ -135,6 +137,7 @@ class VITSGenerator(torch.nn.Module):
d_model=hidden_channels,
num_heads=text_encoder_attention_heads,
dim_feedforward=hidden_channels * text_encoder_ffn_expand,
cnn_module_kernel=text_encoder_cnn_module_kernel,
num_layers=text_encoder_blocks,
dropout=text_encoder_dropout_rate,
)
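The new argument flows straight through to the depthwise convolution inside each conformer block, so the default kernel of 5 lets every block mix information across two positions on either side. A minimal sketch of the resulting construction, using only the names and default values visible in this hunk:

text_encoder = TextEncoder(
    d_model=192,              # hidden_channels
    num_heads=2,              # text_encoder_attention_heads
    dim_feedforward=192 * 4,  # hidden_channels * text_encoder_ffn_expand
    cnn_module_kernel=5,      # new: kernel of the conv module in each block
    num_layers=6,             # text_encoder_blocks
    dropout=0.1,              # text_encoder_dropout_rate
)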

View File

@@ -103,11 +103,13 @@ from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
import torchaudio
from train2 import get_model, get_params
from train import get_model, get_params, prepare_input
from tokenizer import Tokenizer
from icefall.checkpoint import (
average_checkpoints,
@@ -124,7 +126,6 @@ from icefall.utils import (
write_error_stats,
)
from tts_datamodule import LJSpeechTtsDataModule
from utils import prepare_token_batch
LOG_EPS = math.log(1e-10)
@@ -169,6 +170,13 @@ def get_parser():
help="The experiment dir",
)
parser.add_argument(
"--tokens",
type=str,
default="data/tokens.txt",
help="""Path to tokens.txt.""",
)
return parser
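The file passed via --tokens maps each symbol to an integer ID, one pair per line. A minimal illustration of the expected layout (symbols invented for the example; <blk> and <unk> are assumed to be the blank and OOV entries, since the Tokenizer exposes blank_id and oov_id):

<blk> 0
<unk> 1
a 2
b 3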
@@ -176,6 +184,7 @@ def infer_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
tokenizer: Tokenizer,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
@@ -236,10 +245,16 @@ def infer_dataset(
# We only want one background worker so that serialization is deterministic.
for batch_idx, batch in enumerate(dl):
batch_size = len(batch["text"])
text = batch["text"]
tokens, tokens_lens = prepare_token_batch(text)
tokens = tokenizer.texts_to_token_ids(text)
tokens = k2.RaggedTensor(tokens)
row_splits = tokens.shape.row_splits(1)
tokens_lens = row_splits[1:] - row_splits[:-1]
tokens = tokens.to(device)
tokens_lens = tokens_lens.to(device)
# a tensor of shape (B, T)
tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
audio = batch["audio"]
audio_lens = batch["audio_lens"].tolist()
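The ragged-tensor bookkeeping above is easy to check in isolation; a small sketch with invented token IDs:

import k2

tokens = k2.RaggedTensor([[5, 3, 9], [7, 2]])    # two utterances of token IDs
row_splits = tokens.shape.row_splits(1)          # tensor([0, 3, 5])
tokens_lens = row_splits[1:] - row_splits[:-1]   # tensor([3, 2]): per-utterance lengths
padded = tokens.pad(mode="constant", padding_value=0)
# padded is a (B, T_max) tensor:
# [[5, 3, 9],
#  [7, 2, 0]]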
@@ -296,6 +311,11 @@ def main():
if torch.cuda.is_available():
device = torch.device("cuda", 0)
tokenizer = Tokenizer(params.tokens)
params.blank_id = tokenizer.blank_id
params.oov_id = tokenizer.oov_id
params.vocab_size = tokenizer.vocab_size
logging.info(f"Device: {device}")
logging.info(params)
@@ -348,6 +368,7 @@ def main():
dl=test_dl,
params=params,
model=model,
tokenizer=tokenizer,
)
# save_results(

View File

@@ -45,6 +45,7 @@ class TextEncoder(torch.nn.Module):
d_model: int = 192,
num_heads: int = 2,
dim_feedforward: int = 768,
cnn_module_kernel: int = 5,
num_layers: int = 6,
dropout: float = 0.1,
):
@@ -55,6 +56,7 @@ class TextEncoder(torch.nn.Module):
d_model (int): attention dimension
num_heads (int): number of attention heads
dim_feedforward (int): feedforward dimension
cnn_module_kernel (int): convolution kernel size
num_layers (int): number of encoder layers
dropout (float): dropout rate
"""
@@ -69,6 +71,7 @@ class TextEncoder(torch.nn.Module):
d_model=d_model,
num_heads=num_heads,
dim_feedforward=dim_feedforward,
cnn_module_kernel=cnn_module_kernel,
num_layers=num_layers,
dropout=dropout,
)
@@ -119,6 +122,7 @@ class Transformer(nn.Module):
d_model (int): attention dimension
num_heads (int): number of attention heads
dim_feedforward (int): feedforward dimension
cnn_module_kernel (int): convolution kernel size
num_layers (int): number of encoder layers
dropout (float): dropout rate
"""
@@ -128,6 +132,7 @@ class Transformer(nn.Module):
d_model: int = 192,
num_heads: int = 2,
dim_feedforward: int = 768,
cnn_module_kernel: int = 5,
num_layers: int = 6,
dropout: float = 0.1,
) -> None:
@@ -142,6 +147,7 @@ class Transformer(nn.Module):
d_model=d_model,
num_heads=num_heads,
dim_feedforward=dim_feedforward,
cnn_module_kernel=cnn_module_kernel,
dropout=dropout,
)
self.encoder = TransformerEncoder(encoder_layer, num_layers)
@@ -187,12 +193,22 @@ class TransformerEncoderLayer(nn.Module):
d_model: int,
num_heads: int,
dim_feedforward: int,
cnn_module_kernel: int,
dropout: float = 0.1,
) -> None:
super(TransformerEncoderLayer, self).__init__()
self.feed_forward_macaron = nn.Sequential(
nn.Linear(d_model, dim_feedforward),
Swish(),
nn.Dropout(dropout),
nn.Linear(dim_feedforward, d_model),
)
self.self_attn = RelPositionMultiheadAttention(d_model, num_heads, dropout=dropout)
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
self.feed_forward = nn.Sequential(
nn.Linear(d_model, dim_feedforward),
Swish(),
@@ -200,10 +216,13 @@ class TransformerEncoderLayer(nn.Module):
nn.Linear(dim_feedforward, d_model),
)
self.norm_ff = nn.LayerNorm(d_model) # for the FNN module
self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module
self.norm_mha = nn.LayerNorm(d_model) # for the MHA module
self.norm_conv = nn.LayerNorm(d_model) # for the CNN module
self.norm_final = nn.LayerNorm(d_model) # for the final output of the block
self.norm_ff = nn.LayerNorm(d_model) # for the FNN module
self.ff_scale = 0.5
self.dropout = nn.Dropout(dropout)
def forward(
@@ -220,6 +239,9 @@ class TransformerEncoderLayer(nn.Module):
pos_emb: Positional embedding tensor, of shape (1, seq_len*2-1, pos_dim).
key_padding_mask: the mask for the src keys per batch, of shape (batch_size, seq_len)
"""
# macaron style feed-forward module
src = src + self.ff_scale * self.dropout(self.feed_forward_macaron(self.norm_ff_macaron(src)))
# multi-head self-attention module
src_attn = self.self_attn(
self.norm_mha(src),
@@ -228,6 +250,9 @@ class TransformerEncoderLayer(nn.Module):
)
src = src + self.dropout(src_attn)
# convolution module
src = src + self.dropout(self.conv_module(self.norm_conv(src)))
# feed-forward module
src = src + self.dropout(self.feed_forward(self.norm_ff(src)))
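Taken together, the layer now follows the macaron-style ordering of the Conformer block. Schematically (shorthand names for the modules above; dropout and masks omitted, and assuming norm_final is applied at the end of the block as its comment suggests):

x = x + 0.5 * ff_macaron(norm_ff_macaron(x))  # half-scaled macaron feed-forward
x = x + self_attn(norm_mha(x))                # relative-position self-attention
x = x + conv(norm_conv(x))                    # depthwise convolution module
x = x + ff(norm_ff(x))                        # second feed-forward
x = norm_final(x)                             # final layer norm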
@@ -508,6 +533,95 @@ class RelPositionMultiheadAttention(nn.Module):
return attn_output
class ConvolutionModule(nn.Module):
"""ConvolutionModule in Conformer model.
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py
Args:
channels (int): The number of channels of conv layers.
kernel_size (int): Kernel size of conv layers.
bias (bool): Whether to use bias in conv layers (default=True).
"""
def __init__(
self,
channels: int,
kernel_size: int,
bias: bool = True,
) -> None:
"""Construct an ConvolutionModule object."""
super(ConvolutionModule, self).__init__()
# kernel_size should be an odd number for 'SAME' padding
assert (kernel_size - 1) % 2 == 0
self.pointwise_conv1 = nn.Conv1d(
channels,
2 * channels,
kernel_size=1,
stride=1,
padding=0,
bias=bias,
)
padding = (kernel_size - 1) // 2
self.depthwise_conv = nn.Conv1d(
channels,
channels,
kernel_size,
stride=1,
padding=padding,
groups=channels,
bias=bias,
)
self.norm = nn.LayerNorm(channels)
self.pointwise_conv2 = nn.Conv1d(
channels,
channels,
kernel_size=1,
stride=1,
padding=0,
bias=bias,
)
self.activation = Swish()
def forward(
self,
x: Tensor,
src_key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
"""Compute convolution module.
Args:
x: Input tensor (#time, batch, channels).
src_key_padding_mask: the mask for the src keys per batch (optional).
Returns:
Tensor: Output tensor (#time, batch, channels).
"""
# exchange the temporal dimension and the feature dimension
x = x.permute(1, 2, 0) # (#batch, channels, time).
# GLU mechanism
x = self.pointwise_conv1(x) # (batch, 2*channels, time)
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
# 1D Depthwise Conv
if src_key_padding_mask is not None:
x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
x = self.depthwise_conv(x)
# x is (batch, channels, time)
x = x.permute(0, 2, 1)
x = self.norm(x)
x = x.permute(0, 2, 1)
x = self.activation(x)
x = self.pointwise_conv2(x) # (batch, channels, time)
return x.permute(2, 0, 1)
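A quick shape check of the module as defined above (sizes arbitrary):

import torch

module = ConvolutionModule(channels=192, kernel_size=5)
x = torch.randn(17, 4, 192)     # (time, batch, channels)
y = module(x)
assert y.shape == (17, 4, 192)  # the (time, batch, channels) layout is preserved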
class Swish(nn.Module):
"""Construct an Swish object."""

View File

@@ -61,6 +61,7 @@ class VITS(nn.Module):
"segment_size": 32,
"text_encoder_attention_heads": 2,
"text_encoder_ffn_expand": 4,
"text_encoder_cnn_module_kernel": 5,
"text_encoder_blocks": 6,
"text_encoder_dropout_rate": 0.1,
"decoder_kernel_size": 7,