mirror of https://github.com/k2-fsa/icefall.git
use Conformer as text encoder

commit 8d09f8e6bf (parent b719581e2f)
@@ -44,6 +44,7 @@ class VITSGenerator(torch.nn.Module):
         segment_size: int = 32,
         text_encoder_attention_heads: int = 2,
         text_encoder_ffn_expand: int = 4,
+        text_encoder_cnn_module_kernel: int = 5,
         text_encoder_blocks: int = 6,
         text_encoder_dropout_rate: float = 0.1,
         decoder_kernel_size: int = 7,
@@ -89,6 +90,7 @@ class VITSGenerator(torch.nn.Module):
                 of text encoder.
             text_encoder_ffn_expand (int): Expansion ratio of FFN in conformer block
                 of text encoder.
+            text_encoder_cnn_module_kernel (int): Convolution kernel size in text encoder.
             text_encoder_blocks (int): Number of conformer blocks in text encoder.
             text_encoder_dropout_rate (float): Dropout rate in conformer block of
                 text encoder.
@@ -135,6 +137,7 @@ class VITSGenerator(torch.nn.Module):
             d_model=hidden_channels,
             num_heads=text_encoder_attention_heads,
             dim_feedforward=hidden_channels * text_encoder_ffn_expand,
+            cnn_module_kernel=text_encoder_cnn_module_kernel,
             num_layers=text_encoder_blocks,
             dropout=text_encoder_dropout_rate,
         )
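The three hunks above add the new text_encoder_cnn_module_kernel option to the generator and pass it through to the text encoder. A hedged sketch (not part of the diff) of how the constructor arguments work out with the defaults shown in this commit; hidden_channels = 192 is an assumption, chosen to match the encoder defaults d_model=192 and dim_feedforward=768 that appear later in the diff:

# Illustrative values only; hidden_channels is assumed, the rest are the diff's defaults.
hidden_channels = 192                 # assumed
text_encoder_ffn_expand = 4
text_encoder_cnn_module_kernel = 5    # the new option

text_encoder_kwargs = dict(
    d_model=hidden_channels,                                     # 192
    num_heads=2,
    dim_feedforward=hidden_channels * text_encoder_ffn_expand,   # 192 * 4 = 768
    cnn_module_kernel=text_encoder_cnn_module_kernel,            # 5
    num_layers=6,
    dropout=0.1,
)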
@@ -103,11 +103,13 @@ from concurrent.futures import ThreadPoolExecutor
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple

+import k2
 import torch
 import torch.nn as nn
 import torchaudio

-from train2 import get_model, get_params
+from train import get_model, get_params, prepare_input
+from tokenizer import Tokenizer

 from icefall.checkpoint import (
     average_checkpoints,
@@ -124,7 +126,6 @@ from icefall.utils import (
     write_error_stats,
 )
 from tts_datamodule import LJSpeechTtsDataModule
-from utils import prepare_token_batch

 LOG_EPS = math.log(1e-10)

@@ -169,6 +170,13 @@ def get_parser():
         help="The experiment dir",
     )

+    parser.add_argument(
+        "--tokens",
+        type=str,
+        default="data/tokens.txt",
+        help="""Path to tokens.txt.""",
+    )
+
     return parser


@@ -176,6 +184,7 @@ def infer_dataset(
     dl: torch.utils.data.DataLoader,
     params: AttributeDict,
     model: nn.Module,
+    tokenizer: Tokenizer,
 ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
     """Decode dataset.

@@ -236,10 +245,16 @@ def infer_dataset(
     # We only want one background worker so that serialization is deterministic.
     for batch_idx, batch in enumerate(dl):
         batch_size = len(batch["text"])

         text = batch["text"]
-        tokens, tokens_lens = prepare_token_batch(text)
+        tokens = tokenizer.texts_to_token_ids(text)
+        tokens = k2.RaggedTensor(tokens)
+        row_splits = tokens.shape.row_splits(1)
+        tokens_lens = row_splits[1:] - row_splits[:-1]
         tokens = tokens.to(device)
         tokens_lens = tokens_lens.to(device)
+        # a tensor of shape (B, T)
+        tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
+
         audio = batch["audio"]
         audio_lens = batch["audio_lens"].tolist()
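The replacement of prepare_token_batch above derives the token lengths from a k2.RaggedTensor and then pads it into a regular (B, T) tensor. A minimal, self-contained sketch (not part of the diff) of the same computation, with made-up token IDs and blank_id assumed to be 0:

import k2

# Hypothetical token IDs for a batch of 3 utterances.
token_ids = [[5, 9, 2], [7, 3], [4, 4, 4, 1]]

tokens = k2.RaggedTensor(token_ids)
row_splits = tokens.shape.row_splits(1)         # tensor([0, 3, 5, 9], dtype=torch.int32)
tokens_lens = row_splits[1:] - row_splits[:-1]  # tensor([3, 2, 4], dtype=torch.int32)

# Pad to a regular (B, T) tensor; in the recipe padding_value is tokenizer.blank_id.
padded = tokens.pad(mode="constant", padding_value=0)
print(padded.shape)   # torch.Size([3, 4])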
@@ -296,6 +311,11 @@ def main():
     if torch.cuda.is_available():
         device = torch.device("cuda", 0)

+    tokenizer = Tokenizer(params.tokens)
+    params.blank_id = tokenizer.blank_id
+    params.oov_id = tokenizer.oov_id
+    params.vocab_size = tokenizer.vocab_size
+
     logging.info(f"Device: {device}")
     logging.info(params)

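For context, a hedged usage sketch of the Tokenizer introduced in this diff. The method and attribute names (texts_to_token_ids, blank_id, oov_id, vocab_size) are exactly the ones referenced above; the token file path is the default of the new --tokens argument, and running this requires the recipe's local tokenizer module:

from tokenizer import Tokenizer  # local recipe module used by this change

tokenizer = Tokenizer("data/tokens.txt")   # default of the new --tokens argument
print(tokenizer.vocab_size, tokenizer.blank_id, tokenizer.oov_id)

# Maps a list of strings to a list of token-id lists; the inference loop then
# wraps the result in k2.RaggedTensor as shown earlier.
token_ids = tokenizer.texts_to_token_ids(["hello world"])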
@@ -348,6 +368,7 @@ def main():
         dl=test_dl,
         params=params,
         model=model,
+        tokenizer=tokenizer,
     )

     # save_results(
@@ -45,6 +45,7 @@ class TextEncoder(torch.nn.Module):
         d_model: int = 192,
         num_heads: int = 2,
         dim_feedforward: int = 768,
+        cnn_module_kernel: int = 5,
         num_layers: int = 6,
         dropout: float = 0.1,
     ):
@@ -55,6 +56,7 @@ class TextEncoder(torch.nn.Module):
           d_model (int): attention dimension
           num_heads (int): number of attention heads
           dim_feedforward (int): feedforward dimention
+          cnn_module_kernel (int): convolution kernel size
           num_layers (int): number of encoder layers
           dropout (float): dropout rate
         """
@@ -69,6 +71,7 @@ class TextEncoder(torch.nn.Module):
             d_model=d_model,
             num_heads=num_heads,
             dim_feedforward=dim_feedforward,
+            cnn_module_kernel=cnn_module_kernel,
             num_layers=num_layers,
             dropout=dropout,
         )
@@ -119,6 +122,7 @@ class Transformer(nn.Module):
         d_model (int): attention dimension
         num_heads (int): number of attention heads
         dim_feedforward (int): feedforward dimention
+        cnn_module_kernel (int): convolution kernel size
         num_layers (int): number of encoder layers
         dropout (float): dropout rate
     """
@@ -128,6 +132,7 @@ class Transformer(nn.Module):
         d_model: int = 192,
         num_heads: int = 2,
         dim_feedforward: int = 768,
+        cnn_module_kernel: int = 5,
         num_layers: int = 6,
         dropout: float = 0.1,
     ) -> None:
@@ -142,6 +147,7 @@ class Transformer(nn.Module):
             d_model=d_model,
             num_heads=num_heads,
             dim_feedforward=dim_feedforward,
+            cnn_module_kernel=cnn_module_kernel,
             dropout=dropout,
         )
         self.encoder = TransformerEncoder(encoder_layer, num_layers)
@@ -187,12 +193,22 @@ class TransformerEncoderLayer(nn.Module):
         d_model: int,
         num_heads: int,
         dim_feedforward: int,
+        cnn_module_kernel: int,
         dropout: float = 0.1,
     ) -> None:
         super(TransformerEncoderLayer, self).__init__()

+        self.feed_forward_macaron = nn.Sequential(
+            nn.Linear(d_model, dim_feedforward),
+            Swish(),
+            nn.Dropout(dropout),
+            nn.Linear(dim_feedforward, d_model),
+        )
+
         self.self_attn = RelPositionMultiheadAttention(d_model, num_heads, dropout=dropout)

+        self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
+
         self.feed_forward = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
             Swish(),
@@ -200,10 +216,13 @@ class TransformerEncoderLayer(nn.Module):
             nn.Linear(dim_feedforward, d_model),
         )

-        self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
+        self.norm_ff_macaron = nn.LayerNorm(d_model)  # for the macaron style FNN module
         self.norm_mha = nn.LayerNorm(d_model)  # for the MHA module
+        self.norm_conv = nn.LayerNorm(d_model)  # for the CNN module
         self.norm_final = nn.LayerNorm(d_model)  # for the final output of the block
+        self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module

+        self.ff_scale = 0.5
         self.dropout = nn.Dropout(dropout)

     def forward(
@@ -220,6 +239,9 @@ class TransformerEncoderLayer(nn.Module):
           pos_emb: Positional embedding tensor, of shape (1, seq_len*2-1, pos_dim).
           key_padding_mask: the mask for the src keys per batch, of shape (batch_size, seq_len)
         """
+        # macaron style feed-forward module
+        src = src + self.ff_scale * self.dropout(self.feed_forward_macaron(self.norm_ff_macaron(src)))
+
         # multi-head self-attention module
         src_attn = self.self_attn(
             self.norm_mha(src),
@@ -228,6 +250,9 @@ class TransformerEncoderLayer(nn.Module):
         )
         src = src + self.dropout(src_attn)

+        # convolution module
+        src = src + self.dropout(self.conv_module(self.norm_conv(src)))
+
         # feed-forward module
         src = src + self.dropout(self.feed_forward(self.norm_ff(src)))

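Taken together, these hunks turn TransformerEncoderLayer into a conformer-style block: macaron feed-forward (scaled by ff_scale = 0.5), self-attention, convolution module, then a second feed-forward (left unscaled in this diff). A toy stand-in sketch (plain PyTorch modules, not the recipe's classes) that only illustrates this ordering; dropout and relative positional encoding are omitted:

import torch
import torch.nn as nn

d_model, num_heads, d_ff, kernel = 192, 2, 768, 5
ff_macaron = nn.Sequential(nn.Linear(d_model, d_ff), nn.SiLU(), nn.Linear(d_ff, d_model))
attn = nn.MultiheadAttention(d_model, num_heads)  # stands in for RelPositionMultiheadAttention
conv = nn.Conv1d(d_model, d_model, kernel, padding=(kernel - 1) // 2, groups=d_model)
ff = nn.Sequential(nn.Linear(d_model, d_ff), nn.SiLU(), nn.Linear(d_ff, d_model))
norm_ff_macaron, norm_mha, norm_conv, norm_ff = (nn.LayerNorm(d_model) for _ in range(4))

src = torch.randn(50, 8, d_model)                    # (time, batch, d_model)
src = src + 0.5 * ff_macaron(norm_ff_macaron(src))   # macaron FF, ff_scale = 0.5
q = norm_mha(src)
src = src + attn(q, q, q)[0]                         # multi-head self-attention
src = src + conv(norm_conv(src).permute(1, 2, 0)).permute(2, 0, 1)  # depthwise conv over time
src = src + ff(norm_ff(src))                         # second FF (unscaled in this diff)
print(src.shape)                                     # torch.Size([50, 8, 192])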
@@ -508,6 +533,95 @@ class RelPositionMultiheadAttention(nn.Module):
         return attn_output


+class ConvolutionModule(nn.Module):
+    """ConvolutionModule in Conformer model.
+    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py
+
+    Args:
+        channels (int): The number of channels of conv layers.
+        kernel_size (int): Kernerl size of conv layers.
+        bias (bool): Whether to use bias in conv layers (default=True).
+    """
+
+    def __init__(
+        self,
+        channels: int,
+        kernel_size: int,
+        bias: bool = True,
+    ) -> None:
+        """Construct an ConvolutionModule object."""
+        super(ConvolutionModule, self).__init__()
+        # kernerl_size should be a odd number for 'SAME' padding
+        assert (kernel_size - 1) % 2 == 0
+
+        self.pointwise_conv1 = nn.Conv1d(
+            channels,
+            2 * channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=bias,
+        )
+
+        padding = (kernel_size - 1) // 2
+        self.depthwise_conv = nn.Conv1d(
+            channels,
+            channels,
+            kernel_size,
+            stride=1,
+            padding=padding,
+            groups=channels,
+            bias=bias,
+        )
+        self.norm = nn.LayerNorm(channels)
+        self.pointwise_conv2 = nn.Conv1d(
+            channels,
+            channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=bias,
+        )
+        self.activation = Swish()
+
+    def forward(
+        self,
+        x: Tensor,
+        src_key_padding_mask: Optional[Tensor] = None,
+    ) -> Tuple[Tensor, Tensor]:
+        """Compute convolution module.
+
+        Args:
+            x: Input tensor (#time, batch, channels).
+            src_key_padding_mask: the mask for the src keys per batch (optional).
+
+        Returns:
+            Tensor: Output tensor (#time, batch, channels).
+
+        """
+        # exchange the temporal dimension and the feature dimension
+        x = x.permute(1, 2, 0)  # (#batch, channels, time).
+
+        # GLU mechanism
+        x = self.pointwise_conv1(x)  # (batch, 2*channels, time)
+        x = nn.functional.glu(x, dim=1)  # (batch, channels, time)
+
+        # 1D Depthwise Conv
+        if src_key_padding_mask is not None:
+            x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
+        x = self.depthwise_conv(x)
+        # x is (batch, channels, time)
+        x = x.permute(0, 2, 1)
+        x = self.norm(x)
+        x = x.permute(0, 2, 1)
+
+        x = self.activation(x)
+
+        x = self.pointwise_conv2(x)  # (batch, channel, time)
+
+        return x.permute(2, 0, 1)
+
+
 class Swish(nn.Module):
     """Construct an Swish object."""

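A quick shape check for the ConvolutionModule added above, as a hedged sketch: it assumes the class (and Swish) defined in the hunk are in scope, e.g. when run inside the modified file; the import path is not shown in the diff, so it is left as a comment:

import torch

# from <the modified text-encoder file> import ConvolutionModule  # path not shown in the diff

conv = ConvolutionModule(channels=192, kernel_size=5)  # defaults used elsewhere in this commit
x = torch.randn(50, 8, 192)   # (#time, batch, channels), as documented in forward()
y = conv(x)
print(y.shape)                # torch.Size([50, 8, 192]), same layout as the input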
@@ -61,6 +61,7 @@ class VITS(nn.Module):
         "segment_size": 32,
         "text_encoder_attention_heads": 2,
         "text_encoder_ffn_expand": 4,
+        "text_encoder_cnn_module_kernel": 5,
         "text_encoder_blocks": 6,
         "text_encoder_dropout_rate": 0.1,
         "decoder_kernel_size": 7,