Initial conformer refactoring, not nearly done

Daniel Povey 2021-08-22 11:47:26 +08:00
parent cbe5ee1111
commit 076a70b62d
2 changed files with 20 additions and 39 deletions

View File

@@ -26,7 +26,6 @@ class Conformer(Transformer):
dropout (float): dropout rate
cnn_module_kernel (int): Kernel size of convolution module
normalize_before (bool): whether to use layer_norm before the first block.
vgg_frontend (bool): whether to use vgg frontend.
"""
def __init__(
@@ -42,10 +41,7 @@ class Conformer(Transformer):
dropout: float = 0.1,
cnn_module_kernel: int = 31,
normalize_before: bool = True,
vgg_frontend: bool = False,
is_espnet_structure: bool = False,
mmi_loss: bool = True,
use_feat_batchnorm: bool = False,
) -> None:
super(Conformer, self).__init__(
num_features=num_features,
@@ -58,9 +54,6 @@ class Conformer(Transformer):
num_decoder_layers=num_decoder_layers,
dropout=dropout,
normalize_before=normalize_before,
vgg_frontend=vgg_frontend,
mmi_loss=mmi_loss,
use_feat_batchnorm=use_feat_batchnorm,
)
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
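For reference, a minimal usage sketch of the refactored constructor. The argument names come from the hunks above and from the Transformer base class further down; the values and the import path are illustrative assumptions, not part of this commit.

# Hypothetical usage sketch; values are arbitrary and the import path is assumed.
from conformer import Conformer

model = Conformer(
    num_features=80,        # e.g. 80-dim fbank features
    num_classes=500,        # e.g. BPE vocabulary size (from the Transformer base class)
    d_model=256,
    nhead=4,
    dropout=0.1,
    cnn_module_kernel=31,
    normalize_before=True,  # pre-layer norm
)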

View File

@@ -6,7 +6,6 @@ from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
from subsampling import Conv2dSubsampling, VggSubsampling
from torch.nn.utils.rnn import pad_sequence
# Note: TorchScript requires Dict/List/etc. to be fully typed.
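The note above is why the module keeps its container annotations fully parameterized. A minimal sketch, not part of this diff, of what "fully typed" means for TorchScript; the alias name and the dictionary key are assumptions used only for illustration.

from typing import Dict, Optional

import torch

# Fully typed alias: both key and value types are spelled out, so TorchScript
# can compile functions that use it.  A bare `Dict` would be rejected.
Supervisions = Dict[str, torch.Tensor]


@torch.jit.script
def get_entry(supervision: Optional[Supervisions]) -> torch.Tensor:
    # Returns one entry of the dict if present, else an empty tensor.
    if supervision is None:
        return torch.empty(0)
    return supervision["sequence_idx"]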
@@ -18,7 +17,6 @@ class Transformer(nn.Module):
self,
num_features: int,
num_classes: int,
subsampling_factor: int = 4,
d_model: int = 256,
nhead: int = 4,
dim_feedforward: int = 2048,
@@ -26,9 +24,6 @@ class Transformer(nn.Module):
num_decoder_layers: int = 6,
dropout: float = 0.1,
normalize_before: bool = True,
vgg_frontend: bool = False,
mmi_loss: bool = True,
use_feat_batchnorm: bool = False,
) -> None:
"""
Args:
@@ -54,16 +49,9 @@ class Transformer(nn.Module):
Dropout in encoder/decoder.
normalize_before:
If True, use pre-layer norm; False to use post-layer norm.
vgg_frontend:
True to use vgg style frontend for subsampling.
mmi_loss:
True if the model is trained with MMI loss; an extra symbol id is then
reserved for sos/eos (see decoder_num_class below).
use_feat_batchnorm:
True to use batchnorm for the input layer.
"""
super().__init__()
self.use_feat_batchnorm = use_feat_batchnorm
if use_feat_batchnorm:
self.feat_batchnorm = nn.BatchNorm1d(num_features)
self.num_features = num_features
self.num_classes = num_classes
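Since the docstring above contrasts pre-layer norm with post-layer norm, here is a minimal sketch (not the repo's encoder layer) of the two residual arrangements that `normalize_before` switches between:

import torch
import torch.nn as nn


def pre_ln(x: torch.Tensor, sublayer: nn.Module, norm: nn.LayerNorm) -> torch.Tensor:
    # normalize_before=True: normalize the input, apply the sublayer,
    # then add the residual.
    return x + sublayer(norm(x))


def post_ln(x: torch.Tensor, sublayer: nn.Module, norm: nn.LayerNorm) -> torch.Tensor:
    # normalize_before=False: apply the sublayer, add the residual,
    # then normalize the sum.
    return norm(x + sublayer(x))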
@@ -76,10 +64,10 @@ class Transformer(nn.Module):
# That is, it does two things simultaneously:
# (1) subsampling: T -> T//subsampling_factor
# (2) embedding: num_features -> d_model
if vgg_frontend:
self.encoder_embed = VggSubsampling(num_features, d_model)
else:
self.encoder_embed = Conv2dSubsampling(num_features, d_model)
#self.encoder_embed = [TODO...]
self.encoder_pos = PositionalEncoding(d_model, dropout)
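A minimal sketch of a module that does the two jobs described in the comments above: subsample T by a factor of 4 with two stride-2 convolutions and project the feature dimension to d_model. This is an illustrative stand-in under assumed kernel/stride choices, not the repo's Conv2dSubsampling.

import torch
import torch.nn as nn


class SimpleConv2dSubsampling(nn.Module):
    def __init__(self, num_features: int, d_model: int) -> None:
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, d_model, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(d_model, d_model, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        )
        # After two stride-2 convs the feature axis also shrinks by roughly 4.
        self.out = nn.Linear(d_model * ((num_features + 3) // 4), d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (N, T, num_features) -> (N, 1, T, num_features)
        x = self.conv(x.unsqueeze(1))
        n, c, t, f = x.shape
        # (N, C, T', F') -> (N, T', C * F') -> (N, T', d_model)
        return self.out(x.transpose(1, 2).reshape(n, t, c * f))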
@@ -108,14 +96,7 @@ class Transformer(nn.Module):
)
if num_decoder_layers > 0:
if mmi_loss:
self.decoder_num_class = (
self.num_classes + 1
) # +1 for the sos/eos symbol
else:
self.decoder_num_class = (
self.num_classes
) # bpe model already has sos/eos symbol
self.decoder_num_class = self.num_classes
self.decoder_embed = nn.Embedding(
num_embeddings=self.decoder_num_class, embedding_dim=d_model
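The removed branch above reserved one extra id for sos/eos when training with MMI; with a BPE vocabulary that already contains that symbol, the embedding table needs exactly num_classes rows. A small illustrative sketch (sizes are assumptions):

import torch
import torch.nn as nn

num_classes = 500   # BPE vocabulary size, sos/eos already included
d_model = 256

decoder_embed = nn.Embedding(num_embeddings=num_classes, embedding_dim=d_model)

token_ids = torch.randint(0, num_classes, (2, 7))   # (N, U) int64 token ids
print(decoder_embed(token_ids).shape)               # torch.Size([2, 7, 256])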
@@ -150,12 +131,22 @@ class Transformer(nn.Module):
self.decoder_criterion = None
def forward(
self, x: torch.Tensor, supervision: Optional[Supervisions] = None
self,
src_symbols: torch.Tensor,
src_padding_mask: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""
Args:
x:
The input tensor. Its shape is [N, T, C].
src_symbols:
The input symbols to be embedded (query positions will actually be
masked), as a Tensor of shape (batch_size, seq_len), i.e. (N, T), with
dtype=torch.int64.
src_padding_mask:
Either None, or a Tensor of shape (batch_size, seq_len), i.e. (N, T), with
dtype=torch.bool. True marks positions that should be masked in attention
layers and convolutions because they represent padding at the ends of
sequences.
supervision:
Supervision in lhotse format.
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
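A minimal sketch (the helper name is an assumption) of how a src_padding_mask with the semantics documented above can be built from per-sequence lengths, with True marking padded tail positions:

import torch

def make_src_padding_mask(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    # lengths: (N,) number of valid positions per sequence.
    # Returns (N, max_len) with True wherever position index >= length.
    positions = torch.arange(max_len, device=lengths.device)
    return positions.unsqueeze(0) >= lengths.unsqueeze(1)

# Example: two sequences of length 3 and 5, padded to T = 5.
mask = make_src_padding_mask(torch.tensor([3, 5]), max_len=5)
# tensor([[False, False, False,  True,  True],
#         [False, False, False, False, False]])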
@@ -171,10 +162,7 @@ class Transformer(nn.Module):
memory_key_padding_mask for the decoder. Its shape is [N, T].
It is None if `supervision` is None.
"""
if self.use_feat_batchnorm:
x = x.permute(0, 2, 1) # [N, T, C] -> [N, C, T]
x = self.feat_batchnorm(x)
x = x.permute(0, 2, 1) # [N, C, T] -> [N, T, C]
encoder_memory, memory_key_padding_mask = self.run_encoder(
x, supervision
)
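For reference, a standalone sketch of the feature batchnorm pattern removed above: nn.BatchNorm1d expects (N, C, T), so the [N, T, C] input is permuted, normalized, and permuted back. The shapes below are illustrative assumptions.

import torch
import torch.nn as nn

num_features = 80
feat_batchnorm = nn.BatchNorm1d(num_features)

x = torch.randn(4, 100, num_features)   # [N, T, C]
x = x.permute(0, 2, 1)                  # [N, T, C] -> [N, C, T]
x = feat_batchnorm(x)                   # normalize each feature channel over N and T
x = x.permute(0, 2, 1)                  # [N, C, T] -> [N, T, C]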