#!/usr/bin/env python3

# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu)
# Apache 2.0

import math
from typing import Dict, List, Optional, Tuple

import k2
import torch
from torch import Tensor, nn

from icefall.utils import get_texts

# Note: TorchScript requires Dict/List/etc. to be fully typed.
Supervisions = Dict[str, Tensor]


class Transformer(nn.Module):
    """
    Args:
        num_features (int): Number of input features
        num_classes (int): Number of output classes
        subsampling_factor (int): subsampling factor of the encoder
            (the convolution layers before the transformer)
        d_model (int): attention dimension
        nhead (int): number of heads
        dim_feedforward (int): feedforward dimension
        num_encoder_layers (int): number of encoder layers
        num_decoder_layers (int): number of decoder layers
        dropout (float): dropout rate
        normalize_before (bool): whether to use layer_norm before the first block.
        vgg_frontend (bool): whether to use the vgg frontend.
        mmi_loss (bool): if True, reserve an extra decoder output class for the
            sos/eos symbol; if False, the vocabulary (e.g., a BPE model) is
            assumed to already contain it.
        use_feat_batchnorm (bool): whether to apply batch normalization to the
            input features.
    """

    def __init__(
        self,
        num_features: int,
        num_classes: int,
        subsampling_factor: int = 4,
        d_model: int = 256,
        nhead: int = 4,
        dim_feedforward: int = 2048,
        num_encoder_layers: int = 12,
        num_decoder_layers: int = 6,
        dropout: float = 0.1,
        normalize_before: bool = True,
        vgg_frontend: bool = False,
        mmi_loss: bool = True,
        use_feat_batchnorm: bool = False,
    ) -> None:
        super().__init__()
        self.use_feat_batchnorm = use_feat_batchnorm
        if use_feat_batchnorm:
            self.feat_batchnorm = nn.BatchNorm1d(num_features)

        self.num_features = num_features
        self.num_classes = num_classes
        self.subsampling_factor = subsampling_factor
        if subsampling_factor != 4:
            raise NotImplementedError("Support only 'subsampling_factor=4'.")

        self.encoder_embed = (
            VggSubsampling(num_features, d_model)
            if vgg_frontend
            else Conv2dSubsampling(num_features, d_model)
        )
        self.encoder_pos = PositionalEncoding(d_model, dropout)

        encoder_layer = TransformerEncoderLayer(
            d_model,
            nhead,
            dim_feedforward,
            dropout,
            normalize_before=normalize_before,
        )

        if normalize_before:
            encoder_norm = nn.LayerNorm(d_model)
        else:
            encoder_norm = None

        self.encoder = nn.TransformerEncoder(
            encoder_layer, num_encoder_layers, encoder_norm
        )

        self.encoder_output_layer = nn.Sequential(
            nn.Dropout(p=dropout), nn.Linear(d_model, num_classes)
        )

        if num_decoder_layers > 0:
            if mmi_loss:
                self.decoder_num_class = (
                    self.num_classes + 1
                )  # +1 for the sos/eos symbol
            else:
                self.decoder_num_class = (
                    self.num_classes
                )  # bpe model already has sos/eos symbol

            self.decoder_embed = nn.Embedding(self.decoder_num_class, d_model)
            self.decoder_pos = PositionalEncoding(d_model, dropout)

            decoder_layer = TransformerDecoderLayer(
                d_model,
                nhead,
                dim_feedforward,
                dropout,
                normalize_before=normalize_before,
            )

            if normalize_before:
                decoder_norm = nn.LayerNorm(d_model)
            else:
                decoder_norm = None

            self.decoder = nn.TransformerDecoder(
                decoder_layer, num_decoder_layers, decoder_norm
            )

            self.decoder_output_layer = torch.nn.Linear(
                d_model, self.decoder_num_class
            )

            self.decoder_criterion = LabelSmoothingLoss(self.decoder_num_class)
        else:
            self.decoder_criterion = None

    def forward(
        self, x: Tensor, supervision: Optional[Supervisions] = None
    ) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
        """
        Args:
            x: Tensor of dimension (batch_size, num_features, input_length).
            supervision: Supervision in lhotse format, obtained from batch['supervisions'].

        Returns:
            Tensor: After log-softmax, tensor of dimension (batch_size, number_of_classes, input_length).
            Tensor: Before the linear layer, tensor of dimension (input_length, batch_size, d_model).
            Optional[Tensor]: Mask tensor of dimension (batch_size, input_length) or None.
        """
        if self.use_feat_batchnorm:
            x = self.feat_batchnorm(x)
        encoder_memory, memory_mask = self.encode(x, supervision)
        x = self.encoder_output(encoder_memory)
        return x, encoder_memory, memory_mask

    def encode(
        self, x: Tensor, supervisions: Optional[Supervisions] = None
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """
        Args:
            x: Tensor of dimension (batch_size, num_features, input_length).
            supervisions: Supervision in lhotse format, i.e., batch['supervisions'].

        Returns:
            Tensor: Predictor tensor of dimension (input_length, batch_size, d_model).
            Optional[Tensor]: Mask tensor of dimension (batch_size, input_length) or None.
        """
        x = x.permute(0, 2, 1)  # (B, F, T) -> (B, T, F)

        x = self.encoder_embed(x)
        x = self.encoder_pos(x)
        x = x.permute(1, 0, 2)  # (B, T, F) -> (T, B, F)
        mask = encoder_padding_mask(x.size(0), supervisions)
        mask = mask.to(x.device) if mask is not None else None
        x = self.encoder(x, src_key_padding_mask=mask)  # (T, B, F)

        return x, mask

    def encoder_output(self, x: Tensor) -> Tensor:
        """
        Args:
            x: Tensor of dimension (input_length, batch_size, d_model).

        Returns:
            Tensor: After log-softmax, tensor of dimension (batch_size, number_of_classes, input_length).
        """
        x = self.encoder_output_layer(x).permute(
            1, 2, 0
        )  # (T, B, F) -> (B, F, T)
        x = nn.functional.log_softmax(x, dim=1)  # (B, F, T)
        return x

    def decoder_forward(
        self,
        x: Tensor,
        encoder_mask: Tensor,
        supervision: Optional[Supervisions] = None,
        graph_compiler: object = None,
        token_ids: Optional[List[List[int]]] = None,
    ) -> Tensor:
        """
        Args:
            x: Tensor of dimension (input_length, batch_size, d_model).
            encoder_mask: Mask tensor of dimension (batch_size, input_length).
            supervision: Supervision in lhotse format, obtained from batch['supervisions'].
            graph_compiler: uses graph_compiler.L_inv (its labels are words,
                while its aux_labels are phones), graph_compiler.lexicon.words
                and graph_compiler.oov.
            token_ids: token IDs of each utterance; used when supervision and
                graph_compiler are not given.

        Returns:
            Tensor: Decoder loss.
        """
        if supervision is not None and graph_compiler is not None:
            batch_text = get_normal_transcripts(
                supervision, graph_compiler.lexicon.words, graph_compiler.oov
            )
            ys_in_pad, ys_out_pad = add_sos_eos(
                batch_text,
                graph_compiler.L_inv,
                self.decoder_num_class - 1,
                self.decoder_num_class - 1,
            )
        elif token_ids is not None:
            # special token ids:
            #   <blank>   0
            #   <UNK>     1
            #   <sos/eos> self.decoder_num_class - 1
            sos_id = self.decoder_num_class - 1
            eos_id = self.decoder_num_class - 1
            _sos = torch.tensor([sos_id])
            _eos = torch.tensor([eos_id])
            ys_in = [
                torch.cat([_sos, torch.tensor(y)], dim=0) for y in token_ids
            ]
            ys_out = [
                torch.cat([torch.tensor(y), _eos], dim=0) for y in token_ids
            ]
            ys_in_pad = pad_list(ys_in, eos_id)
            ys_out_pad = pad_list(ys_out, -1)

        else:
            raise ValueError("Invalid input for decoder self attention")

        ys_in_pad = ys_in_pad.to(x.device)
        ys_out_pad = ys_out_pad.to(x.device)

        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
            x.device
        )

        tgt_key_padding_mask = decoder_padding_mask(ys_in_pad)

        tgt = self.decoder_embed(ys_in_pad)  # (B, T) -> (B, T, F)
        tgt = self.decoder_pos(tgt)
        tgt = tgt.permute(1, 0, 2)  # (B, T, F) -> (T, B, F)
        pred_pad = self.decoder(
            tgt=tgt,
            memory=x,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=encoder_mask,
        )  # (T, B, F)
        pred_pad = pred_pad.permute(1, 0, 2)  # (T, B, F) -> (B, T, F)
        pred_pad = self.decoder_output_layer(pred_pad)  # (B, T, F)

        decoder_loss = self.decoder_criterion(pred_pad, ys_out_pad)

        return decoder_loss

    def decoder_nll(
        self,
        x: Tensor,
        encoder_mask: Tensor,
        token_ids: Optional[List[List[int]]] = None,
    ) -> Tensor:
        """
        Args:
            x: encoder output, Tensor of dimension (input_length, batch_size, d_model).
            encoder_mask: Mask tensor of dimension (batch_size, input_length).
            token_ids: n-best list extracted from the lattice before rescoring.

        Returns:
            Tensor: negative log-likelihood.
        """
        # The common part between this function and decoder_forward could be
        # extracted as a separate function.
        if token_ids is not None:
            # special token ids:
            #   <blank>   0
            #   <UNK>     1
            #   <sos/eos> self.decoder_num_class - 1
            sos_id = self.decoder_num_class - 1
            eos_id = self.decoder_num_class - 1
            _sos = torch.tensor([sos_id])
            _eos = torch.tensor([eos_id])
            ys_in = [
                torch.cat([_sos, torch.tensor(y)], dim=0) for y in token_ids
            ]
            ys_out = [
                torch.cat([torch.tensor(y), _eos], dim=0) for y in token_ids
            ]
            ys_in_pad = pad_list(ys_in, eos_id)
            ys_out_pad = pad_list(ys_out, -1)
        else:
            raise ValueError("Invalid input for decoder self attention")

        ys_in_pad = ys_in_pad.to(x.device, dtype=torch.int64)
        ys_out_pad = ys_out_pad.to(x.device, dtype=torch.int64)

        tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
            x.device
        )

        tgt_key_padding_mask = decoder_padding_mask(ys_in_pad)

        tgt = self.decoder_embed(ys_in_pad)  # (B, T) -> (B, T, F)
        tgt = self.decoder_pos(tgt)
        tgt = tgt.permute(1, 0, 2)  # (B, T, F) -> (T, B, F)
        pred_pad = self.decoder(
            tgt=tgt,
            memory=x,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=encoder_mask,
        )  # (T, B, F)
        pred_pad = pred_pad.permute(1, 0, 2)  # (T, B, F) -> (B, T, F)
        pred_pad = self.decoder_output_layer(pred_pad)  # (B, T, F)
        # nll: negative log-likelihood
        nll = torch.nn.functional.cross_entropy(
            pred_pad.view(-1, self.decoder_num_class),
            ys_out_pad.view(-1),
            ignore_index=-1,
            reduction="none",
        )

        nll = nll.view(pred_pad.shape[0], -1)

        return nll


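# A minimal sketch (not part of the original recipe) of how the attention
# decoder can be used for rescoring: run the encoder once, then score a few
# hypothetical token sequences with decoder_nll(). All sizes and token IDs
# below are made up for illustration only.
def _demo_decoder_nll() -> None:
    model = Transformer(num_features=40, num_classes=500)
    feats = torch.rand(2, 40, 100)  # (batch, num_features, num_frames)
    with torch.no_grad():
        # memory: (T', batch, d_model); memory_mask is None without supervisions
        memory, memory_mask = model.encode(feats)
        token_ids = [[7, 8, 9], [10, 11]]  # hypothetical n-best token sequences
        nll = model.decoder_nll(memory, memory_mask, token_ids=token_ids)
    # One row of token-level negative log-likelihoods per hypothesis.
    assert nll.shape[0] == len(token_ids)

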
class TransformerEncoderLayer(nn.Module):
    """
    Modified from torch.nn.TransformerEncoderLayer.
    Adds support for normalize_before, i.e., applying layer_norm before the
    first block.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of the intermediate layer, relu or gelu (default=relu).
        normalize_before: whether to use layer_norm before the first block.

    Examples::
        >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> out = encoder_layer(src)
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: str = "relu",
        normalize_before: bool = True,
    ) -> None:
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

        self.normalize_before = normalize_before

    def __setstate__(self, state):
        if "activation" not in state:
            state["activation"] = nn.functional.relu
        super(TransformerEncoderLayer, self).__setstate__(state)

    def forward(
        self,
        src: Tensor,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            src: (S, N, E).
            src_mask: (S, S).
            src_key_padding_mask: (N, S).
            S is the source sequence length, N is the batch size, E is the feature number.
        """
        residual = src
        if self.normalize_before:
            src = self.norm1(src)
        src2 = self.self_attn(
            src,
            src,
            src,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
        )[0]
        src = residual + self.dropout1(src2)
        if not self.normalize_before:
            src = self.norm1(src)

        residual = src
        if self.normalize_before:
            src = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = residual + self.dropout2(src2)
        if not self.normalize_before:
            src = self.norm2(src)
        return src


class TransformerDecoderLayer(nn.Module):
    """
    Modified from torch.nn.TransformerDecoderLayer.
    Adds support for normalize_before, i.e., applying layer_norm before the
    first block.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of the intermediate layer, relu or gelu (default=relu).
        normalize_before: whether to use layer_norm before the first block.

    Examples::
        >>> decoder_layer = TransformerDecoderLayer(d_model=512, nhead=8)
        >>> memory = torch.rand(10, 32, 512)
        >>> tgt = torch.rand(20, 32, 512)
        >>> out = decoder_layer(tgt, memory)
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: str = "relu",
        normalize_before: bool = True,
    ) -> None:
        super(TransformerDecoderLayer, self).__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0)
        self.src_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

        self.normalize_before = normalize_before

    def __setstate__(self, state):
        if "activation" not in state:
            state["activation"] = nn.functional.relu
        super(TransformerDecoderLayer, self).__setstate__(state)

    def forward(
        self,
        tgt: Tensor,
        memory: Tensor,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """Pass the inputs (and mask) through the decoder layer.

        Args:
            tgt: the sequence to the decoder layer (required).
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).

        Shape:
            tgt: (T, N, E).
            memory: (S, N, E).
            tgt_mask: (T, T).
            memory_mask: (T, S).
            tgt_key_padding_mask: (N, T).
            memory_key_padding_mask: (N, S).
            S is the source sequence length, T is the target sequence length,
            N is the batch size, E is the feature number.
        """
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)
        tgt2 = self.self_attn(
            tgt,
            tgt,
            tgt,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask,
        )[0]
        tgt = residual + self.dropout1(tgt2)
        if not self.normalize_before:
            tgt = self.norm1(tgt)

        residual = tgt
        if self.normalize_before:
            tgt = self.norm2(tgt)
        tgt2 = self.src_attn(
            tgt,
            memory,
            memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )[0]
        tgt = residual + self.dropout2(tgt2)
        if not self.normalize_before:
            tgt = self.norm2(tgt)

        residual = tgt
        if self.normalize_before:
            tgt = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = residual + self.dropout3(tgt2)
        if not self.normalize_before:
            tgt = self.norm3(tgt)
        return tgt


def _get_activation_fn(activation: str):
    if activation == "relu":
        return nn.functional.relu
    elif activation == "gelu":
        return nn.functional.gelu

    raise RuntimeError(
        "activation should be relu/gelu, not {}".format(activation)
    )


class Conv2dSubsampling(nn.Module):
    """Convolutional 2D subsampling (to 1/4 length).
    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py

    Args:
        idim: Input dimension.
        odim: Output dimension.
    """

    def __init__(self, idim: int, odim: int) -> None:
        """Construct a Conv2dSubsampling object."""
        super(Conv2dSubsampling, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels=1, out_channels=odim, kernel_size=3, stride=2
            ),
            nn.ReLU(),
            nn.Conv2d(
                in_channels=odim, out_channels=odim, kernel_size=3, stride=2
            ),
            nn.ReLU(),
        )
        self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)

    def forward(self, x: Tensor) -> Tensor:
        """Subsample x.

        Args:
            x: Input tensor of dimension (batch_size, input_length, num_features),
                i.e., (#batch, time, idim).

        Returns:
            torch.Tensor: Subsampled tensor of dimension (batch_size, time', odim),
                where time' = ((time - 1) // 2 - 1) // 2.
        """
        x = x.unsqueeze(1)  # (b, c, t, f)
        x = self.conv(x)
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        return x


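# A minimal sketch (not part of the original module) of the shape contract of
# Conv2dSubsampling: the time axis shrinks to roughly 1/4 of its length and
# the feature axis is projected to odim. The sizes below are made up.
def _demo_conv2d_subsampling() -> None:
    subsampler = Conv2dSubsampling(idim=80, odim=256)
    x = torch.rand(8, 100, 80)  # (batch, time, idim)
    y = subsampler(x)
    # time' = ((100 - 1) // 2 - 1) // 2 = 24
    assert y.shape == (8, 24, 256)

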
class VggSubsampling(nn.Module):
    """Tries to follow the setup described in https://arxiv.org/pdf/1910.09799.pdf.
    The paper is not 100% explicit, so this is partly guesswork, informed by
    comparison with other VGG implementations.

    Args:
        idim: Input dimension.
        odim: Output dimension.
    """

    def __init__(self, idim: int, odim: int) -> None:
        """Construct a VggSubsampling object. This uses 2 VGG blocks with 2
        Conv2d layers each, subsampling its input by a factor of 4 in the
        time dimension.

        Args:
            idim: Number of features at input, e.g. 40 or 80 for MFCC
                (will be treated as the image height).
            odim: Output dimension (number of features), e.g. 256.
        """
        super(VggSubsampling, self).__init__()

        cur_channels = 1
        layers = []
        block_dims = [32, 64]

        # The decision to use padding=1 for the 1st convolution, then padding=0
        # for the 2nd and for the max-pooling, and ceil_mode=True, was driven by
        # a back-compatibility concern so that the number of frames at the
        # output would be equal to:
        #   (((T-1)//2)-1)//2.
        # We can consider changing this by using padding=1 on the 2nd convolution,
        # so the num-frames at the output would be T//4.
        for block_dim in block_dims:
            layers.append(
                torch.nn.Conv2d(
                    in_channels=cur_channels,
                    out_channels=block_dim,
                    kernel_size=3,
                    padding=1,
                    stride=1,
                )
            )
            layers.append(torch.nn.ReLU())
            layers.append(
                torch.nn.Conv2d(
                    in_channels=block_dim,
                    out_channels=block_dim,
                    kernel_size=3,
                    padding=0,
                    stride=1,
                )
            )
            layers.append(
                torch.nn.MaxPool2d(
                    kernel_size=2, stride=2, padding=0, ceil_mode=True
                )
            )
            cur_channels = block_dim

        self.layers = nn.Sequential(*layers)

        self.out = nn.Linear(
            block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim
        )

    def forward(self, x: Tensor) -> Tensor:
        """Subsample x.

        Args:
            x: Input tensor of dimension (batch_size, input_length, num_features),
                i.e., (#batch, time, idim).

        Returns:
            torch.Tensor: Subsampled tensor of dimension (batch_size, input_length', odim),
                where input_length' = (((input_length - 1) // 2) - 1) // 2.
        """
        x = x.unsqueeze(1)  # (b, c, t, f)
        x = self.layers(x)
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        return x


class PositionalEncoding(nn.Module):
    """
    Positional encoding.

    Args:
        d_model: Embedding dimension.
        dropout: Dropout rate.
        max_len: Maximum input length.
    """

    def __init__(
        self, d_model: int, dropout: float = 0.1, max_len: int = 5000
    ) -> None:
        """Construct a PositionalEncoding object."""
        super(PositionalEncoding, self).__init__()
        self.d_model = d_model
        self.xscale = math.sqrt(self.d_model)
        self.dropout = nn.Dropout(p=dropout)
        self.pe = None
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))

    def extend_pe(self, x: Tensor) -> None:
        """Reset the positional encodings."""
        if self.pe is not None:
            if self.pe.size(1) >= x.size(1):
                if self.pe.dtype != x.dtype or self.pe.device != x.device:
                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return
        pe = torch.zeros(x.size(1), self.d_model)
        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.pe = pe.to(device=x.device, dtype=x.dtype)

    def forward(self, x: Tensor) -> Tensor:
        """
        Add positional encoding.

        Args:
            x: Input tensor of dimension (batch_size, input_length, d_model).

        Returns:
            torch.Tensor: Encoded tensor of dimension (batch_size, input_length, d_model).
        """
        self.extend_pe(x)
        x = x * self.xscale + self.pe[:, : x.size(1)]
        return self.dropout(x)


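# A minimal sketch (not part of the original module): PositionalEncoding
# scales the input by sqrt(d_model) and adds fixed sinusoidal embeddings,
# leaving the (batch, time, d_model) shape unchanged. Sizes are made up.
def _demo_positional_encoding() -> None:
    pos_enc = PositionalEncoding(d_model=256, dropout=0.1)
    x = torch.rand(8, 100, 256)  # (batch, time, d_model)
    y = pos_enc(x)
    assert y.shape == x.shape

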
class Noam(object):
    """
    Implements the Noam optimizer. Proposed in
    "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf
    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining parameter groups
        model_size: attention dimension of the transformer model
        factor: learning rate factor
        warm_step: warmup steps
    """

    def __init__(
        self,
        params,
        model_size: int = 256,
        factor: float = 10.0,
        warm_step: int = 25000,
        weight_decay=0,
    ) -> None:
        """Construct a Noam object."""
        self.optimizer = torch.optim.Adam(
            params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay
        )
        self._step = 0
        self.warmup = warm_step
        self.factor = factor
        self.model_size = model_size
        self._rate = 0

    @property
    def param_groups(self):
        """Return param_groups."""
        return self.optimizer.param_groups

    def step(self):
        """Update parameters and rate."""
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p["lr"] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step=None):
        """Compute the learning rate for the given (or current) step."""
        if step is None:
            step = self._step
        return (
            self.factor
            * self.model_size ** (-0.5)
            * min(step ** (-0.5), step * self.warmup ** (-1.5))
        )

    def zero_grad(self):
        """Reset gradient."""
        self.optimizer.zero_grad()

    def state_dict(self):
        """Return state_dict."""
        return {
            "_step": self._step,
            "warmup": self.warmup,
            "factor": self.factor,
            "model_size": self.model_size,
            "_rate": self._rate,
            "optimizer": self.optimizer.state_dict(),
        }

    def load_state_dict(self, state_dict):
        """Load state_dict."""
        for key, value in state_dict.items():
            if key == "optimizer":
                self.optimizer.load_state_dict(state_dict["optimizer"])
            else:
                setattr(self, key, value)


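# A minimal sketch (not part of the original recipe) of how the Noam schedule
# behaves: the learning rate grows linearly for warm_step steps and then
# decays proportionally to step ** -0.5. The tiny model below is hypothetical
# and only used to instantiate the optimizer.
def _demo_noam_schedule() -> None:
    model = nn.Linear(4, 4)
    optimizer = Noam(
        model.parameters(), model_size=256, factor=10.0, warm_step=25000
    )
    for step in (1, 1000, 25000, 100000):
        print(f"step={step:6d}  lr={optimizer.rate(step):.6f}")

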
class LabelSmoothingLoss(nn.Module):
    """
    Label-smoothing loss. The KL-divergence between
    q_{smoothed ground truth prob.}(w) and p_{prob. computed by model}(w)
    is minimized.
    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/label_smoothing_loss.py

    Args:
        size: the number of classes
        padding_idx: ignored class id
        smoothing: smoothing rate (0.0 means the conventional CE)
        normalize_length: normalize loss by sequence length if True
        criterion: loss function to be smoothed
    """

    def __init__(
        self,
        size: int,
        padding_idx: int = -1,
        smoothing: float = 0.1,
        normalize_length: bool = False,
        criterion: nn.Module = nn.KLDivLoss(reduction="none"),
    ) -> None:
        """Construct a LabelSmoothingLoss object."""
        super(LabelSmoothingLoss, self).__init__()
        self.criterion = criterion
        self.padding_idx = padding_idx
        assert 0.0 < smoothing <= 1.0
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        self.true_dist = None
        self.normalize_length = normalize_length

    def forward(self, x: Tensor, target: Tensor) -> Tensor:
        """
        Compute the loss between x and target.

        Args:
            x: prediction of dimension (batch_size, input_length, number_of_classes).
            target: target masked with self.padding_idx, of dimension (batch_size, input_length).

        Returns:
            torch.Tensor: scalar float value
        """
        assert x.size(2) == self.size
        batch_size = x.size(0)
        x = x.view(-1, self.size)
        target = target.view(-1)
        with torch.no_grad():
            true_dist = x.clone()
            true_dist.fill_(self.smoothing / (self.size - 1))
            ignore = target == self.padding_idx  # (B,)
            total = len(target) - ignore.sum().item()
            target = target.masked_fill(ignore, 0)  # avoid -1 index
            true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
        kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
        denom = total if self.normalize_length else batch_size
        return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom


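# A minimal sketch (not part of the original module) of LabelSmoothingLoss:
# logits of shape (batch, length, num_classes) are scored against padded
# targets, and positions equal to padding_idx (-1) are ignored.
# All numbers below are made up.
def _demo_label_smoothing_loss() -> None:
    num_classes = 10
    criterion = LabelSmoothingLoss(
        size=num_classes, padding_idx=-1, smoothing=0.1
    )
    logits = torch.randn(2, 5, num_classes)  # (batch, length, num_classes)
    target = torch.randint(0, num_classes, (2, 5))
    target[1, 3:] = -1  # mark the tail of the second sequence as padding
    loss = criterion(logits, target)
    assert loss.dim() == 0  # scalar loss

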
def encoder_padding_mask(
    max_len: int, supervisions: Optional[Supervisions] = None
) -> Optional[Tensor]:
    """Make a mask tensor containing indices of the padded part.

    Args:
        max_len: maximum length of the input features
        supervisions: Supervision in lhotse format, i.e., batch['supervisions']

    Returns:
        Tensor: Mask tensor of dimension (batch_size, input_length);
        True denotes a masked (padded) position.
    """
    if supervisions is None:
        return None

    supervision_segments = torch.stack(
        (
            supervisions["sequence_idx"],
            supervisions["start_frame"],
            supervisions["num_frames"],
        ),
        1,
    ).to(torch.int32)

    lengths = [
        0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)
    ]
    for idx in range(supervision_segments.size(0)):
        # Note: TorchScript doesn't allow to unpack tensors as tuples
        sequence_idx = supervision_segments[idx, 0].item()
        start_frame = supervision_segments[idx, 1].item()
        num_frames = supervision_segments[idx, 2].item()
        lengths[sequence_idx] = start_frame + num_frames

    # Account for the subsampling frontend, which outputs
    # (((T - 1) // 2) - 1) // 2 frames.
    lengths = [((i - 1) // 2 - 1) // 2 for i in lengths]
    bs = int(len(lengths))
    seq_range = torch.arange(0, max_len, dtype=torch.int64)
    seq_range_expand = seq_range.unsqueeze(0).expand(bs, max_len)
    # Note: TorchScript doesn't implement Tensor.new()
    seq_length_expand = torch.tensor(
        lengths, device=seq_range_expand.device, dtype=seq_range_expand.dtype
    ).unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand

    return mask


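# A minimal sketch (not part of the original module) of encoder_padding_mask
# with a hand-built supervisions dict: two cuts of 20 and 12 frames become
# 4 and 2 frames after the 4x subsampling, so the shorter one is padded.
def _demo_encoder_padding_mask() -> None:
    supervisions = {
        "sequence_idx": torch.tensor([0, 1]),
        "start_frame": torch.tensor([0, 0]),
        "num_frames": torch.tensor([20, 12]),
    }
    mask = encoder_padding_mask(max_len=4, supervisions=supervisions)
    # lengths after subsampling: ((20-1)//2-1)//2 = 4 and ((12-1)//2-1)//2 = 2
    assert mask.tolist() == [
        [False, False, False, False],
        [False, False, True, True],
    ]

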
def decoder_padding_mask(ys_pad: Tensor, ignore_id: int = -1) -> Tensor:
    """Generate a length mask for the input. Masked positions are filled with
    True; unmasked positions are filled with False.

    Args:
        ys_pad: padded tensor of dimension (batch_size, input_length).
        ignore_id: the ignored value (the padding value) in ys_pad

    Returns:
        Tensor: a bool mask tensor of dimension (batch_size, input_length).
    """
    ys_mask = ys_pad == ignore_id
    return ys_mask


def get_normal_transcripts(
    supervision: Supervisions, words: k2.SymbolTable, oov: str = "<UNK>"
) -> List[List[int]]:
    """Get normal transcripts (1 input recording has 1 transcript) from the
    lhotse cut format, by concatenating the transcripts corresponding to the
    same recording.

    Args:
        supervision: Supervision in lhotse format, i.e., batch['supervisions']
        words: The word symbol table.
        oov: Out-of-vocabulary word.

    Returns:
        List[List[int]]: List of concatenated transcripts; its length is batch_size.
    """

    texts = [
        [token if token in words else oov for token in text.split(" ")]
        for text in supervision["text"]
    ]
    texts_ids = [[words[token] for token in text] for text in texts]

    batch_text = [
        [] for _ in range(int(supervision["sequence_idx"].max().item()) + 1)
    ]
    for sequence_idx, text in zip(supervision["sequence_idx"], texts_ids):
        batch_text[sequence_idx] = batch_text[sequence_idx] + text
    return batch_text


def generate_square_subsequent_mask(sz: int) -> Tensor:
    """Generate a square mask for the sequence. The masked positions are
    filled with float('-inf'). Unmasked positions are filled with float(0.0).

    Args:
        sz: mask size

    Returns:
        Tensor: a square mask of dimension (sz, sz)
    """
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    mask = (
        mask.float()
        .masked_fill(mask == 0, float("-inf"))
        .masked_fill(mask == 1, float(0.0))
    )
    return mask


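# A minimal sketch (not part of the original module): for sz=3 the mask lets
# each position attend only to itself and earlier positions.
def _demo_square_subsequent_mask() -> None:
    mask = generate_square_subsequent_mask(3)
    # mask:
    # [[0., -inf, -inf],
    #  [0.,   0., -inf],
    #  [0.,   0.,   0.]]
    assert mask.shape == (3, 3)
    assert torch.isinf(mask).sum().item() == 3

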
def add_sos_eos(
    ys: List[List[int]],
    lexicon: k2.Fsa,
    sos: int,
    eos: int,
    ignore_id: int = -1,
) -> Tuple[Tensor, Tensor]:
    """Add <sos> and <eos> labels.

    Args:
        ys: batch of unpadded target sequences
        lexicon: Its labels are words, while its aux_labels are phones.
        sos: index of <sos>
        eos: index of <eos>
        ignore_id: index of padding

    Returns:
        Tensor: Input of the transformer decoder.
            Padded tensor of dimension (batch_size, max_length).
        Tensor: Output of the transformer decoder.
            Padded tensor of dimension (batch_size, max_length).
    """

    _sos = torch.tensor([sos])
    _eos = torch.tensor([eos])
    ys = get_hierarchical_targets(ys, lexicon)
    ys_in = [torch.cat([_sos, y], dim=0) for y in ys]
    ys_out = [torch.cat([y, _eos], dim=0) for y in ys]
    return pad_list(ys_in, eos), pad_list(ys_out, ignore_id)


def pad_list(ys: List[Tensor], pad_value: float) -> Tensor:
    """Perform padding for a list of tensors.

    Args:
        ys: List of tensors. len(ys) = batch_size.
        pad_value: Value for padding.

    Returns:
        Tensor: Padded tensor (batch_size, max_length, `*`).

    Examples:
        >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
        >>> x
        [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
        >>> pad_list(x, 0)
        tensor([[1., 1., 1., 1.],
                [1., 1., 0., 0.],
                [1., 0., 0., 0.]])
    """
    n_batch = len(ys)
    max_len = max(x.size(0) for x in ys)
    pad = ys[0].new_full((n_batch, max_len, *ys[0].size()[1:]), pad_value)

    for i in range(n_batch):
        pad[i, : ys[i].size(0)] = ys[i]

    return pad


def get_hierarchical_targets(
    ys: List[List[int]], lexicon: k2.Fsa
) -> List[Tensor]:
    """Get hierarchical transcripts (i.e., phone level transcripts) from
    transcripts (i.e., word level transcripts).

    Args:
        ys: Word level transcripts.
        lexicon: Its labels are words, while its aux_labels are phones.

    Returns:
        List[Tensor]: Phone level transcripts.
    """

    if lexicon is None:
        return ys
    else:
        L_inv = lexicon

    n_batch = len(ys)
    device = L_inv.device

    transcripts = k2.create_fsa_vec(
        [k2.linear_fsa(x, device=device) for x in ys]
    )
    transcripts_with_self_loops = k2.add_epsilon_self_loops(transcripts)

    transcripts_lexicon = k2.intersect(
        L_inv, transcripts_with_self_loops, treat_epsilons_specially=False
    )
    # Don't call invert_() above because we want to return phone IDs,
    # which are the `aux_labels` of transcripts_lexicon.
    transcripts_lexicon = k2.remove_epsilon(transcripts_lexicon)
    transcripts_lexicon = k2.top_sort(transcripts_lexicon)

    transcripts_lexicon = k2.shortest_path(
        transcripts_lexicon, use_double_scores=True
    )

    ys = get_texts(transcripts_lexicon)
    ys = [torch.tensor(y) for y in ys]

    return ys


def test_transformer():
    t = Transformer(40, 1281)
    T = 200
    f = torch.rand(31, 40, T)
    g, _, _ = t(f)
    assert g.shape == (31, 1281, (((T - 1) // 2) - 1) // 2)


def main():
    test_transformer()


if __name__ == "__main__":
    main()