diff --git a/egs/librispeech/ASR/conformer_lm/conformer.py b/egs/librispeech/ASR/conformer_lm/conformer.py
index a00664a99..3014055b4 100644
--- a/egs/librispeech/ASR/conformer_lm/conformer.py
+++ b/egs/librispeech/ASR/conformer_lm/conformer.py
@@ -26,7 +26,6 @@ class Conformer(Transformer):
         dropout (float): dropout rate
         cnn_module_kernel (int): Kernel size of convolution module
         normalize_before (bool): whether to use layer_norm before the first block.
-        vgg_frontend (bool): whether to use vgg frontend.
     """
 
     def __init__(
@@ -42,10 +41,7 @@ class Conformer(Transformer):
         dropout: float = 0.1,
         cnn_module_kernel: int = 31,
         normalize_before: bool = True,
-        vgg_frontend: bool = False,
         is_espnet_structure: bool = False,
-        mmi_loss: bool = True,
-        use_feat_batchnorm: bool = False,
     ) -> None:
         super(Conformer, self).__init__(
             num_features=num_features,
@@ -58,9 +54,6 @@ class Conformer(Transformer):
             num_decoder_layers=num_decoder_layers,
             dropout=dropout,
             normalize_before=normalize_before,
-            vgg_frontend=vgg_frontend,
-            mmi_loss=mmi_loss,
-            use_feat_batchnorm=use_feat_batchnorm,
         )
 
         self.encoder_pos = RelPositionalEncoding(d_model, dropout)
diff --git a/egs/librispeech/ASR/conformer_lm/transformer.py b/egs/librispeech/ASR/conformer_lm/transformer.py
index 51c77b220..707eacd1b 100644
--- a/egs/librispeech/ASR/conformer_lm/transformer.py
+++ b/egs/librispeech/ASR/conformer_lm/transformer.py
@@ -6,7 +6,6 @@ from typing import Dict, List, Optional, Tuple
 
 import torch
 import torch.nn as nn
-from subsampling import Conv2dSubsampling, VggSubsampling
 from torch.nn.utils.rnn import pad_sequence
 
 # Note: TorchScript requires Dict/List/etc. to be fully typed.
@@ -18,7 +17,6 @@ class Transformer(nn.Module):
         self,
         num_features: int,
         num_classes: int,
-        subsampling_factor: int = 4,
         d_model: int = 256,
         nhead: int = 4,
         dim_feedforward: int = 2048,
@@ -26,9 +24,6 @@ class Transformer(nn.Module):
         num_decoder_layers: int = 6,
         dropout: float = 0.1,
         normalize_before: bool = True,
-        vgg_frontend: bool = False,
-        mmi_loss: bool = True,
-        use_feat_batchnorm: bool = False,
     ) -> None:
         """
         Args:
@@ -54,16 +49,9 @@ class Transformer(nn.Module):
             Dropout in encoder/decoder.
           normalize_before:
             If True, use pre-layer norm; False to use post-layer norm.
-          vgg_frontend:
-            True to use vgg style frontend for subsampling.
-          mmi_loss:
-          use_feat_batchnorm:
-            True to use batchnorm for the input layer.
         """
         super().__init__()
-        self.use_feat_batchnorm = use_feat_batchnorm
-        if use_feat_batchnorm:
-            self.feat_batchnorm = nn.BatchNorm1d(num_features)
+
         self.num_features = num_features
         self.num_classes = num_classes
@@ -76,10 +64,10 @@ class Transformer(nn.Module):
         # That is, it does two things simultaneously:
         #   (1) subsampling: T -> T//subsampling_factor
         #   (2) embedding: num_classes -> d_model
-        if vgg_frontend:
-            self.encoder_embed = VggSubsampling(num_features, d_model)
-        else:
-            self.encoder_embed = Conv2dSubsampling(num_features, d_model)
+
+
+        #self.encoder_embed = [TODO...]
+
         self.encoder_pos = PositionalEncoding(d_model, dropout)
@@ -108,14 +96,7 @@ class Transformer(nn.Module):
         )
 
         if num_decoder_layers > 0:
-            if mmi_loss:
-                self.decoder_num_class = (
-                    self.num_classes + 1
-                )  # +1 for the sos/eos symbol
-            else:
-                self.decoder_num_class = (
-                    self.num_classes
-                )  # bpe model already has sos/eos symbol
+            self.decoder_num_class = self.num_classes
 
             self.decoder_embed = nn.Embedding(
                 num_embeddings=self.decoder_num_class, embedding_dim=d_model
@@ -150,12 +131,22 @@ class Transformer(nn.Module):
             self.decoder_criterion = None
 
     def forward(
-        self, x: torch.Tensor, supervision: Optional[Supervisions] = None
+        self,
+        src_symbols: torch.Tensor,
+        src_padding_mask: torch.Tensor = None
     ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
         """
         Args:
-          x:
-            The input tensor. Its shape is [N, T, C].
+          src_symbols:
+            The input symbols to be embedded (query positions will actually be
+            masked), as a Tensor of shape (batch_size, seq_len) and
+            dtype=torch.int64, i.e. shape (N, T).
+          src_padding_mask:
+            Either None, or a Tensor of shape (batch_size, seq_len), i.e. (N, T),
+            and dtype=torch.bool, which has True in positions to be masked in
+            attention layers and convolutions because they represent padding at
+            the ends of sequences.
+
           supervision:
             Supervision in lhotse format.
             See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
@@ -171,10 +162,7 @@ class Transformer(nn.Module):
             memory_key_padding_mask for the decoder. Its shape is [N, T].
             It is None if `supervision` is None.
         """
-        if self.use_feat_batchnorm:
-            x = x.permute(0, 2, 1)  # [N, T, C] -> [N, C, T]
-            x = self.feat_batchnorm(x)
-            x = x.permute(0, 2, 1)  # [N, C, T] -> [N, T, C]
+
         encoder_memory, memory_key_padding_mask = self.run_encoder(
             x, supervision
         )
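Note on the two new placeholders in the patch: the Conv2dSubsampling/VggSubsampling frontend is deleted but not yet replaced (`#self.encoder_embed = [TODO...]`), and `forward()` now takes integer `src_symbols` plus an optional boolean `src_padding_mask` instead of acoustic features. The sketch below is illustrative only and not part of the patch: it assumes the frontend becomes a plain token embedding (no time subsampling, since the input is already a symbol sequence) and shows one common way to build the `(N, T)` padding mask from sequence lengths. `TokenEmbeddingFrontend` and `make_src_padding_mask` are hypothetical names.

```python
import torch
import torch.nn as nn


class TokenEmbeddingFrontend(nn.Module):
    """Hypothetical stand-in for the removed Conv2dSubsampling frontend.

    The LM variant of the model consumes symbol IDs of shape (N, T), so a
    plain embedding lookup (no time subsampling) is assumed here.
    """

    def __init__(self, num_classes: int, d_model: int) -> None:
        super().__init__()
        self.embed = nn.Embedding(num_embeddings=num_classes, embedding_dim=d_model)

    def forward(self, src_symbols: torch.Tensor) -> torch.Tensor:
        # (N, T) int64 symbol IDs -> (N, T, d_model) float embeddings
        return self.embed(src_symbols)


def make_src_padding_mask(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    """Build the (N, T) bool mask described in the new forward() docstring:
    True marks padded positions at the ends of sequences."""
    positions = torch.arange(max_len, device=lengths.device).unsqueeze(0)  # (1, T)
    return positions >= lengths.unsqueeze(1)  # (N, T), True where padded


# Example: two sequences of lengths 5 and 3, padded to T = 5.
lengths = torch.tensor([5, 3])
mask = make_src_padding_mask(lengths, max_len=5)
# mask:
# tensor([[False, False, False, False, False],
#         [False, False, False,  True,  True]])
```

A mask built this way follows the same True-means-masked convention as `src_key_padding_mask` in PyTorch's `nn.TransformerEncoder`, matching the semantics described in the new docstring.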