from local

This commit is contained in:
dohe0342 2023-02-02 11:36:43 +09:00
parent 203cedb453
commit ab44c2d54a
3 changed files with 1 additions and 128 deletions

View File

@ -31,7 +31,7 @@ from scaling import (
) )
from subsampling import Conv2dSubsampling from subsampling import Conv2dSubsampling
from torch import Tensor, nn from torch import Tensor, nn
from transformer import Supervisions, Transformer, encoder_padding_mask from transformer import Supervisions, Transformer, encoder_padding_mask, TransformerEncoder, TransformerEncoder
class Conformer(Transformer): class Conformer(Transformer):
@ -161,133 +161,6 @@ class Conformer(Transformer):
return x, mask return x, mask
class TransfEnc(Transformer):
"""
Args:
num_features (int): Number of input features
num_classes (int): Number of output classes
subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
d_model (int): attention dimension, also the output dimension
nhead (int): number of head
dim_feedforward (int): feedforward dimention
num_encoder_layers (int): number of encoder layers
num_decoder_layers (int): number of decoder layers
dropout (float): dropout rate
layer_dropout (float): layer-dropout rate.
cnn_module_kernel (int): Kernel size of convolution module
vgg_frontend (bool): whether to use vgg frontend.
"""
def __init__(
self,
num_features: int,
num_classes: int,
subsampling_factor: int = 4,
d_model: int = 256,
nhead: int = 4,
dim_feedforward: int = 2048,
num_encoder_layers: int = 12,
num_decoder_layers: int = 6,
dropout: float = 0.1,
layer_dropout: float = 0.075,
cnn_module_kernel: int = 31,
group_num: int = 0,
) -> None:
super(TransfEnc, self).__init__(
num_features=num_features,
num_classes=num_classes,
subsampling_factor=subsampling_factor,
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
num_encoder_layers=num_encoder_layers,
num_decoder_layers=num_decoder_layers,
dropout=dropout,
layer_dropout=layer_dropout,
)
self.num_features = num_features
self.subsampling_factor = subsampling_factor
if subsampling_factor != 4:
raise NotImplementedError("Support only 'subsampling_factor=4'.")
# self.encoder_embed converts the input of shape (N, T, num_features)
# to the shape (N, T//subsampling_factor, d_model).
# That is, it does two things simultaneously:
# (1) subsampling: T -> T//subsampling_factor
# (2) embedding: num_features -> d_model
self.encoder_embed = Conv2dSubsampling(num_features, d_model)
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
encoder_layer = ConformerEncoderLayer(
d_model,
nhead,
dim_feedforward,
dropout,
layer_dropout,
cnn_module_kernel,
)
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
self.group_num = group_num
if self.group_num != 0:
self.group_layer_num = int(num_encoder_layers // self.group_num)
self.alpha = nn.Parameter(torch.rand(self.group_num))
self.sigmoid = nn.Sigmoid()
self.layer_norm = nn.LayerNorm(d_model)
def run_encoder(
self,
x: torch.Tensor,
supervisions: Optional[Supervisions] = None,
warmup: float = 1.0,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Args:
x:
The input tensor. Its shape is (batch_size, seq_len, feature_dim).
supervisions:
Supervision in lhotse format.
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
CAUTION: It contains length information, i.e., start and number of
frames, before subsampling
It is read directly from the batch, without any sorting. It is used
to compute encoder padding mask, which is used as memory key padding
mask for the decoder.
warmup:
A floating point value that gradually increases from 0 throughout
training; when it is >= 1.0 we are "fully warmed up". It is used
to turn modules on sequentially.
Returns:
Tensor: Predictor tensor of dimension (input_length, batch_size, d_model).
Tensor: Mask tensor of dimension (batch_size, input_length)
"""
x = self.encoder_embed(x)
x, pos_emb = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
mask = encoder_padding_mask(x.size(0), supervisions)
if mask is not None:
mask = mask.to(x.device)
# Caution: We assume the subsampling factor is 4!
x, layer_outputs = self.encoder(
x, pos_emb, src_key_padding_mask=mask, warmup=warmup
) # (T, N, C)
if self.group_num != 0:
x = 0
for enum, alpha in enumerate(self.alpha):
x += self.sigmoid(alpha) * layer_outputs[(enum+1)*self.group_layer_num-1]
x = self.layer_norm(x/self.group_num)
# x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
# return x, lengths
return x, mask
class ConformerEncoderLayer(nn.Module): class ConformerEncoderLayer(nn.Module):
""" """
ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks. ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks.