From 8d09f8e6bfa642eae3635104d1884014ad165d23 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Sun, 5 Nov 2023 18:25:47 +0800 Subject: [PATCH] use Conformer as text encoder --- egs/ljspeech/tts/vits/generator.py | 3 + egs/ljspeech/tts/vits/infer.py | 27 +++++- egs/ljspeech/tts/vits/text_encoder.py | 116 +++++++++++++++++++++++++- egs/ljspeech/tts/vits/vits.py | 1 + 4 files changed, 143 insertions(+), 4 deletions(-) diff --git a/egs/ljspeech/tts/vits/generator.py b/egs/ljspeech/tts/vits/generator.py index a74440c95..fc0d45cfd 100644 --- a/egs/ljspeech/tts/vits/generator.py +++ b/egs/ljspeech/tts/vits/generator.py @@ -44,6 +44,7 @@ class VITSGenerator(torch.nn.Module): segment_size: int = 32, text_encoder_attention_heads: int = 2, text_encoder_ffn_expand: int = 4, + text_encoder_cnn_module_kernel: int = 5, text_encoder_blocks: int = 6, text_encoder_dropout_rate: float = 0.1, decoder_kernel_size: int = 7, @@ -89,6 +90,7 @@ class VITSGenerator(torch.nn.Module): of text encoder. text_encoder_ffn_expand (int): Expansion ratio of FFN in conformer block of text encoder. + text_encoder_cnn_module_kernel (int): Convolution kernel size in text encoder. text_encoder_blocks (int): Number of conformer blocks in text encoder. text_encoder_dropout_rate (float): Dropout rate in conformer block of text encoder. @@ -135,6 +137,7 @@ class VITSGenerator(torch.nn.Module): d_model=hidden_channels, num_heads=text_encoder_attention_heads, dim_feedforward=hidden_channels * text_encoder_ffn_expand, + cnn_module_kernel=text_encoder_cnn_module_kernel, num_layers=text_encoder_blocks, dropout=text_encoder_dropout_rate, ) diff --git a/egs/ljspeech/tts/vits/infer.py b/egs/ljspeech/tts/vits/infer.py index 89fc72962..623cc3ec9 100755 --- a/egs/ljspeech/tts/vits/infer.py +++ b/egs/ljspeech/tts/vits/infer.py @@ -103,11 +103,13 @@ from concurrent.futures import ThreadPoolExecutor from pathlib import Path from typing import Dict, List, Optional, Tuple +import k2 import torch import torch.nn as nn import torchaudio -from train2 import get_model, get_params +from train import get_model, get_params, prepare_input +from tokenizer import Tokenizer from icefall.checkpoint import ( average_checkpoints, @@ -124,7 +126,6 @@ from icefall.utils import ( write_error_stats, ) from tts_datamodule import LJSpeechTtsDataModule -from utils import prepare_token_batch LOG_EPS = math.log(1e-10) @@ -169,6 +170,13 @@ def get_parser(): help="The experiment dir", ) + parser.add_argument( + "--tokens", + type=str, + default="data/tokens.txt", + help="""Path to tokens.txt.""", + ) + return parser @@ -176,6 +184,7 @@ def infer_dataset( dl: torch.utils.data.DataLoader, params: AttributeDict, model: nn.Module, + tokenizer: Tokenizer, ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. @@ -236,10 +245,16 @@ def infer_dataset( # We only want one background worker so that serialization is deterministic. 
for batch_idx, batch in enumerate(dl): batch_size = len(batch["text"]) + text = batch["text"] - tokens, tokens_lens = prepare_token_batch(text) + tokens = tokenizer.texts_to_token_ids(text) + tokens = k2.RaggedTensor(tokens) + row_splits = tokens.shape.row_splits(1) + tokens_lens = row_splits[1:] - row_splits[:-1] tokens = tokens.to(device) tokens_lens = tokens_lens.to(device) + # a tensor of shape (B, T) + tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id) audio = batch["audio"] audio_lens = batch["audio_lens"].tolist() @@ -296,6 +311,11 @@ def main(): if torch.cuda.is_available(): device = torch.device("cuda", 0) + tokenizer = Tokenizer(params.tokens) + params.blank_id = tokenizer.blank_id + params.oov_id = tokenizer.oov_id + params.vocab_size = tokenizer.vocab_size + logging.info(f"Device: {device}") logging.info(params) @@ -348,6 +368,7 @@ def main(): dl=test_dl, params=params, model=model, + tokenizer=tokenizer, ) # save_results( diff --git a/egs/ljspeech/tts/vits/text_encoder.py b/egs/ljspeech/tts/vits/text_encoder.py index fbf9b16a3..9ba8e1768 100644 --- a/egs/ljspeech/tts/vits/text_encoder.py +++ b/egs/ljspeech/tts/vits/text_encoder.py @@ -45,6 +45,7 @@ class TextEncoder(torch.nn.Module): d_model: int = 192, num_heads: int = 2, dim_feedforward: int = 768, + cnn_module_kernel: int = 5, num_layers: int = 6, dropout: float = 0.1, ): @@ -55,6 +56,7 @@ class TextEncoder(torch.nn.Module): d_model (int): attention dimension num_heads (int): number of attention heads dim_feedforward (int): feedforward dimention + cnn_module_kernel (int): convolution kernel size num_layers (int): number of encoder layers dropout (float): dropout rate """ @@ -69,6 +71,7 @@ class TextEncoder(torch.nn.Module): d_model=d_model, num_heads=num_heads, dim_feedforward=dim_feedforward, + cnn_module_kernel=cnn_module_kernel, num_layers=num_layers, dropout=dropout, ) @@ -119,6 +122,7 @@ class Transformer(nn.Module): d_model (int): attention dimension num_heads (int): number of attention heads dim_feedforward (int): feedforward dimention + cnn_module_kernel (int): convolution kernel size num_layers (int): number of encoder layers dropout (float): dropout rate """ @@ -128,6 +132,7 @@ class Transformer(nn.Module): d_model: int = 192, num_heads: int = 2, dim_feedforward: int = 768, + cnn_module_kernel: int = 5, num_layers: int = 6, dropout: float = 0.1, ) -> None: @@ -142,6 +147,7 @@ class Transformer(nn.Module): d_model=d_model, num_heads=num_heads, dim_feedforward=dim_feedforward, + cnn_module_kernel=cnn_module_kernel, dropout=dropout, ) self.encoder = TransformerEncoder(encoder_layer, num_layers) @@ -187,12 +193,22 @@ class TransformerEncoderLayer(nn.Module): d_model: int, num_heads: int, dim_feedforward: int, + cnn_module_kernel: int, dropout: float = 0.1, ) -> None: super(TransformerEncoderLayer, self).__init__() + self.feed_forward_macaron = nn.Sequential( + nn.Linear(d_model, dim_feedforward), + Swish(), + nn.Dropout(dropout), + nn.Linear(dim_feedforward, d_model), + ) + self.self_attn = RelPositionMultiheadAttention(d_model, num_heads, dropout=dropout) + self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) + self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), Swish(), @@ -200,10 +216,13 @@ class TransformerEncoderLayer(nn.Module): nn.Linear(dim_feedforward, d_model), ) - self.norm_ff = nn.LayerNorm(d_model) # for the FNN module + self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module self.norm_mha = nn.LayerNorm(d_model) # for the MHA 
module
+        self.norm_conv = nn.LayerNorm(d_model)  # for the CNN module
         self.norm_final = nn.LayerNorm(d_model)  # for the final output of the block
+        self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
+        self.ff_scale = 0.5
 
         self.dropout = nn.Dropout(dropout)
 
     def forward(
@@ -220,6 +239,9 @@ class TransformerEncoderLayer(nn.Module):
             pos_emb: Positional embedding tensor, of shape (1, seq_len*2-1, pos_dim).
             key_padding_mask: the mask for the src keys per batch, of shape (batch_size, seq_len)
         """
+        # macaron style feed-forward module
+        src = src + self.ff_scale * self.dropout(self.feed_forward_macaron(self.norm_ff_macaron(src)))
+
         # multi-head self-attention module
         src_attn = self.self_attn(
             self.norm_mha(src),
@@ -228,6 +250,9 @@ class TransformerEncoderLayer(nn.Module):
         )
         src = src + self.dropout(src_attn)
 
+        # convolution module
+        src = src + self.dropout(self.conv_module(self.norm_conv(src)))
+
         # feed-forward module
         src = src + self.dropout(self.feed_forward(self.norm_ff(src)))
 
@@ -508,6 +533,95 @@ class RelPositionMultiheadAttention(nn.Module):
         return attn_output
 
 
+class ConvolutionModule(nn.Module):
+    """ConvolutionModule in Conformer model.
+    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py
+
+    Args:
+        channels (int): The number of channels of conv layers.
+        kernel_size (int): Kernel size of conv layers.
+        bias (bool): Whether to use bias in conv layers (default=True).
+    """
+
+    def __init__(
+        self,
+        channels: int,
+        kernel_size: int,
+        bias: bool = True,
+    ) -> None:
+        """Construct a ConvolutionModule object."""
+        super(ConvolutionModule, self).__init__()
+        # kernel_size should be an odd number for 'SAME' padding
+        assert (kernel_size - 1) % 2 == 0
+
+        self.pointwise_conv1 = nn.Conv1d(
+            channels,
+            2 * channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=bias,
+        )
+
+        padding = (kernel_size - 1) // 2
+        self.depthwise_conv = nn.Conv1d(
+            channels,
+            channels,
+            kernel_size,
+            stride=1,
+            padding=padding,
+            groups=channels,
+            bias=bias,
+        )
+        self.norm = nn.LayerNorm(channels)
+        self.pointwise_conv2 = nn.Conv1d(
+            channels,
+            channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=bias,
+        )
+        self.activation = Swish()
+
+    def forward(
+        self,
+        x: Tensor,
+        src_key_padding_mask: Optional[Tensor] = None,
+    ) -> Tensor:
+        """Compute convolution module.
+
+        Args:
+            x: Input tensor (#time, batch, channels).
+            src_key_padding_mask: the mask for the src keys per batch (optional).
+
+        Returns:
+            Tensor: Output tensor (#time, batch, channels).
+
+        """
+        # exchange the temporal dimension and the feature dimension
+        x = x.permute(1, 2, 0)  # (#batch, channels, time).
+ + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channels, time) + x = nn.functional.glu(x, dim=1) # (batch, channels, time) + + # 1D Depthwise Conv + if src_key_padding_mask is not None: + x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + x = self.depthwise_conv(x) + # x is (batch, channels, time) + x = x.permute(0, 2, 1) + x = self.norm(x) + x = x.permute(0, 2, 1) + + x = self.activation(x) + + x = self.pointwise_conv2(x) # (batch, channel, time) + + return x.permute(2, 0, 1) + + class Swish(nn.Module): """Construct an Swish object.""" diff --git a/egs/ljspeech/tts/vits/vits.py b/egs/ljspeech/tts/vits/vits.py index 441e915df..27d9b4c7a 100644 --- a/egs/ljspeech/tts/vits/vits.py +++ b/egs/ljspeech/tts/vits/vits.py @@ -61,6 +61,7 @@ class VITS(nn.Module): "segment_size": 32, "text_encoder_attention_heads": 2, "text_encoder_ffn_expand": 4, + "text_encoder_cnn_module_kernel": 5, "text_encoder_blocks": 6, "text_encoder_dropout_rate": 0.1, "decoder_kernel_size": 7,
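
Note: the token batching that infer.py now performs can be checked in isolation with the short sketch below; the token ids and the blank id are made-up illustration values, not taken from the recipe.

    import k2

    # e.g. the output of tokenizer.texts_to_token_ids(texts) for two utterances
    token_ids = [[5, 7, 9], [2, 4]]
    blank_id = 0  # stand-in for tokenizer.blank_id

    tokens = k2.RaggedTensor(token_ids)
    row_splits = tokens.shape.row_splits(1)         # tensor([0, 3, 5])
    tokens_lens = row_splits[1:] - row_splits[:-1]  # per-utterance lengths: tensor([3, 2])
    tokens = tokens.pad(mode="constant", padding_value=blank_id)
    # tokens is now a regular (B, T) tensor: [[5, 7, 9], [2, 4, 0]]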
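
Note: with this patch each TransformerEncoderLayer applies four pre-norm residual branches in the order macaron feed-forward (scaled by 0.5), self-attention, convolution, feed-forward. The sketch below shows that ordering with stand-in modules; the real layer uses RelPositionMultiheadAttention (which also takes a relative positional embedding and a key padding mask) and the ConvolutionModule added above.

    import torch
    import torch.nn as nn

    d_model, ff_scale = 192, 0.5
    norm_ff_macaron, norm_mha, norm_conv, norm_ff, norm_final = (
        nn.LayerNorm(d_model) for _ in range(5)
    )
    feed_forward_macaron = nn.Linear(d_model, d_model)  # stand-in for the macaron FFN
    self_attn = nn.Identity()                           # stand-in for RelPositionMultiheadAttention
    conv_module = nn.Identity()                         # stand-in for ConvolutionModule
    feed_forward = nn.Linear(d_model, d_model)          # stand-in for the second FFN
    dropout = nn.Dropout(0.1)

    src = torch.randn(50, 4, d_model)  # (seq_len, batch, d_model)
    src = src + ff_scale * dropout(feed_forward_macaron(norm_ff_macaron(src)))  # macaron FFN
    src = src + dropout(self_attn(norm_mha(src)))                               # self-attention
    src = src + dropout(conv_module(norm_conv(src)))                            # convolution
    src = src + dropout(feed_forward(norm_ff(src)))                             # feed-forward
    src = norm_final(src)  # final layer norm of the block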