use Conformer as text encoder

yaozengwei 2023-11-05 18:25:47 +08:00
parent b719581e2f
commit 8d09f8e6bf
4 changed files with 143 additions and 4 deletions

View File

@@ -44,6 +44,7 @@ class VITSGenerator(torch.nn.Module):
         segment_size: int = 32,
         text_encoder_attention_heads: int = 2,
         text_encoder_ffn_expand: int = 4,
+        text_encoder_cnn_module_kernel: int = 5,
         text_encoder_blocks: int = 6,
         text_encoder_dropout_rate: float = 0.1,
         decoder_kernel_size: int = 7,
@@ -89,6 +90,7 @@ class VITSGenerator(torch.nn.Module):
                 of text encoder.
             text_encoder_ffn_expand (int): Expansion ratio of FFN in conformer block
                 of text encoder.
+            text_encoder_cnn_module_kernel (int): Convolution kernel size in text encoder.
             text_encoder_blocks (int): Number of conformer blocks in text encoder.
             text_encoder_dropout_rate (float): Dropout rate in conformer block of
                 text encoder.
@@ -135,6 +137,7 @@ class VITSGenerator(torch.nn.Module):
             d_model=hidden_channels,
             num_heads=text_encoder_attention_heads,
             dim_feedforward=hidden_channels * text_encoder_ffn_expand,
+            cnn_module_kernel=text_encoder_cnn_module_kernel,
             num_layers=text_encoder_blocks,
             dropout=text_encoder_dropout_rate,
         )

View File

@@ -103,11 +103,13 @@ from concurrent.futures import ThreadPoolExecutor
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple

+import k2
 import torch
 import torch.nn as nn
 import torchaudio
-from train2 import get_model, get_params
+from train import get_model, get_params, prepare_input
+from tokenizer import Tokenizer

 from icefall.checkpoint import (
     average_checkpoints,
@@ -124,7 +126,6 @@ from icefall.utils import (
     write_error_stats,
 )
 from tts_datamodule import LJSpeechTtsDataModule
-from utils import prepare_token_batch

 LOG_EPS = math.log(1e-10)
@@ -169,6 +170,13 @@ def get_parser():
         help="The experiment dir",
     )

+    parser.add_argument(
+        "--tokens",
+        type=str,
+        default="data/tokens.txt",
+        help="""Path to tokens.txt.""",
+    )
+
     return parser
@@ -176,6 +184,7 @@ def infer_dataset(
     dl: torch.utils.data.DataLoader,
     params: AttributeDict,
     model: nn.Module,
+    tokenizer: Tokenizer,
 ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
     """Decode dataset.
@@ -236,10 +245,16 @@ def infer_dataset(
     # We only want one background worker so that serialization is deterministic.

     for batch_idx, batch in enumerate(dl):
         batch_size = len(batch["text"])

         text = batch["text"]
-        tokens, tokens_lens = prepare_token_batch(text)
+        tokens = tokenizer.texts_to_token_ids(text)
+        tokens = k2.RaggedTensor(tokens)
+        row_splits = tokens.shape.row_splits(1)
+        tokens_lens = row_splits[1:] - row_splits[:-1]
         tokens = tokens.to(device)
         tokens_lens = tokens_lens.to(device)
+        # a tensor of shape (B, T)
+        tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)

         audio = batch["audio"]
         audio_lens = batch["audio_lens"].tolist()
@@ -296,6 +311,11 @@ def main():
     if torch.cuda.is_available():
         device = torch.device("cuda", 0)

+    tokenizer = Tokenizer(params.tokens)
+    params.blank_id = tokenizer.blank_id
+    params.oov_id = tokenizer.oov_id
+    params.vocab_size = tokenizer.vocab_size
+
     logging.info(f"Device: {device}")
     logging.info(params)
@@ -348,6 +368,7 @@ def main():
         dl=test_dl,
         params=params,
         model=model,
+        tokenizer=tokenizer,
     )

     # save_results(
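For readers unfamiliar with the k2 batching idiom that replaces prepare_token_batch above, here is a minimal standalone sketch of the same pattern; the token ids and the padding value 0 are made up for illustration (in the recipe they come from tokenizer.texts_to_token_ids and tokenizer.blank_id).

    import k2

    # Hypothetical token ids standing in for tokenizer.texts_to_token_ids(text).
    token_ids = [[5, 9, 3], [7, 2]]

    tokens = k2.RaggedTensor(token_ids)
    row_splits = tokens.shape.row_splits(1)
    tokens_lens = row_splits[1:] - row_splits[:-1]  # lengths per utterance: [3, 2]

    # Pad the ragged tensor into a regular (B, T) tensor.
    tokens = tokens.pad(mode="constant", padding_value=0)
    print(tokens)       # [[5, 9, 3], [7, 2, 0]]
    print(tokens_lens)  # [3, 2]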

View File

@@ -45,6 +45,7 @@ class TextEncoder(torch.nn.Module):
         d_model: int = 192,
         num_heads: int = 2,
         dim_feedforward: int = 768,
+        cnn_module_kernel: int = 5,
         num_layers: int = 6,
         dropout: float = 0.1,
     ):
@@ -55,6 +56,7 @@ class TextEncoder(torch.nn.Module):
           d_model (int): attention dimension
           num_heads (int): number of attention heads
           dim_feedforward (int): feedforward dimension
+          cnn_module_kernel (int): convolution kernel size
           num_layers (int): number of encoder layers
           dropout (float): dropout rate
         """
@@ -69,6 +71,7 @@ class TextEncoder(torch.nn.Module):
             d_model=d_model,
             num_heads=num_heads,
             dim_feedforward=dim_feedforward,
+            cnn_module_kernel=cnn_module_kernel,
             num_layers=num_layers,
             dropout=dropout,
         )
@@ -119,6 +122,7 @@ class Transformer(nn.Module):
         d_model (int): attention dimension
         num_heads (int): number of attention heads
         dim_feedforward (int): feedforward dimension
+        cnn_module_kernel (int): convolution kernel size
         num_layers (int): number of encoder layers
         dropout (float): dropout rate
     """
@@ -128,6 +132,7 @@ class Transformer(nn.Module):
         d_model: int = 192,
         num_heads: int = 2,
         dim_feedforward: int = 768,
+        cnn_module_kernel: int = 5,
         num_layers: int = 6,
         dropout: float = 0.1,
     ) -> None:
@@ -142,6 +147,7 @@ class Transformer(nn.Module):
             d_model=d_model,
             num_heads=num_heads,
             dim_feedforward=dim_feedforward,
+            cnn_module_kernel=cnn_module_kernel,
             dropout=dropout,
         )
         self.encoder = TransformerEncoder(encoder_layer, num_layers)
@@ -187,12 +193,22 @@ class TransformerEncoderLayer(nn.Module):
         d_model: int,
         num_heads: int,
         dim_feedforward: int,
+        cnn_module_kernel: int,
         dropout: float = 0.1,
     ) -> None:
         super(TransformerEncoderLayer, self).__init__()

+        self.feed_forward_macaron = nn.Sequential(
+            nn.Linear(d_model, dim_feedforward),
+            Swish(),
+            nn.Dropout(dropout),
+            nn.Linear(dim_feedforward, d_model),
+        )
+
         self.self_attn = RelPositionMultiheadAttention(d_model, num_heads, dropout=dropout)

+        self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
+
         self.feed_forward = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
             Swish(),
@@ -200,10 +216,13 @@ class TransformerEncoderLayer(nn.Module):
             nn.Linear(dim_feedforward, d_model),
         )

-        self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
+        self.norm_ff_macaron = nn.LayerNorm(d_model)  # for the macaron style FNN module
         self.norm_mha = nn.LayerNorm(d_model)  # for the MHA module
+        self.norm_conv = nn.LayerNorm(d_model)  # for the CNN module
         self.norm_final = nn.LayerNorm(d_model)  # for the final output of the block
+        self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
+        self.ff_scale = 0.5

         self.dropout = nn.Dropout(dropout)

     def forward(
@@ -220,6 +239,9 @@ class TransformerEncoderLayer(nn.Module):
             pos_emb: Positional embedding tensor, of shape (1, seq_len*2-1, pos_dim).
             key_padding_mask: the mask for the src keys per batch, of shape (batch_size, seq_len)
         """
+        # macaron style feed-forward module
+        src = src + self.ff_scale * self.dropout(self.feed_forward_macaron(self.norm_ff_macaron(src)))
+
         # multi-head self-attention module
         src_attn = self.self_attn(
             self.norm_mha(src),
@@ -228,6 +250,9 @@ class TransformerEncoderLayer(nn.Module):
         )
         src = src + self.dropout(src_attn)

+        # convolution module
+        src = src + self.dropout(self.conv_module(self.norm_conv(src)))
+
         # feed-forward module
         src = src + self.dropout(self.feed_forward(self.norm_ff(src)))
@ -508,6 +533,95 @@ class RelPositionMultiheadAttention(nn.Module):
return attn_output return attn_output
class ConvolutionModule(nn.Module):
"""ConvolutionModule in Conformer model.
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py
Args:
channels (int): The number of channels of conv layers.
kernel_size (int): Kernerl size of conv layers.
bias (bool): Whether to use bias in conv layers (default=True).
"""
def __init__(
self,
channels: int,
kernel_size: int,
bias: bool = True,
) -> None:
"""Construct an ConvolutionModule object."""
super(ConvolutionModule, self).__init__()
# kernerl_size should be a odd number for 'SAME' padding
assert (kernel_size - 1) % 2 == 0
self.pointwise_conv1 = nn.Conv1d(
channels,
2 * channels,
kernel_size=1,
stride=1,
padding=0,
bias=bias,
)
padding = (kernel_size - 1) // 2
self.depthwise_conv = nn.Conv1d(
channels,
channels,
kernel_size,
stride=1,
padding=padding,
groups=channels,
bias=bias,
)
self.norm = nn.LayerNorm(channels)
self.pointwise_conv2 = nn.Conv1d(
channels,
channels,
kernel_size=1,
stride=1,
padding=0,
bias=bias,
)
self.activation = Swish()
def forward(
self,
x: Tensor,
src_key_padding_mask: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:
"""Compute convolution module.
Args:
x: Input tensor (#time, batch, channels).
src_key_padding_mask: the mask for the src keys per batch (optional).
Returns:
Tensor: Output tensor (#time, batch, channels).
"""
# exchange the temporal dimension and the feature dimension
x = x.permute(1, 2, 0) # (#batch, channels, time).
# GLU mechanism
x = self.pointwise_conv1(x) # (batch, 2*channels, time)
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
# 1D Depthwise Conv
if src_key_padding_mask is not None:
x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
x = self.depthwise_conv(x)
# x is (batch, channels, time)
x = x.permute(0, 2, 1)
x = self.norm(x)
x = x.permute(0, 2, 1)
x = self.activation(x)
x = self.pointwise_conv2(x) # (batch, channel, time)
return x.permute(2, 0, 1)
class Swish(nn.Module): class Swish(nn.Module):
"""Construct an Swish object.""" """Construct an Swish object."""

View File

@@ -61,6 +61,7 @@ class VITS(nn.Module):
         "segment_size": 32,
         "text_encoder_attention_heads": 2,
         "text_encoder_ffn_expand": 4,
+        "text_encoder_cnn_module_kernel": 5,
         "text_encoder_blocks": 6,
         "text_encoder_dropout_rate": 0.1,
         "decoder_kernel_size": 7,