Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 01:52:41 +00:00)
Initial conformer refactoring, not nearly done
This commit is contained in:
parent cbe5ee1111
commit 076a70b62d
@@ -26,7 +26,6 @@ class Conformer(Transformer):
        dropout (float): dropout rate
        cnn_module_kernel (int): Kernel size of convolution module
        normalize_before (bool): whether to use layer_norm before the first block.
        vgg_frontend (bool): whether to use vgg frontend.
    """

    def __init__(
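For reference, a minimal usage sketch of the documented constructor arguments. The values below are purely illustrative, and `from conformer import Conformer` assumes the recipe's module layout; only keyword arguments that appear in this diff are used.

from conformer import Conformer  # assumed module path

model = Conformer(
    num_features=80,        # illustrative input feature dimension
    num_classes=500,        # illustrative output vocabulary size
    d_model=256,
    nhead=4,
    dropout=0.1,
    cnn_module_kernel=31,
    normalize_before=True,  # layer_norm before the first block
    vgg_frontend=False,     # use Conv2dSubsampling instead of the VGG frontend
)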
@@ -42,10 +41,7 @@ class Conformer(Transformer):
        dropout: float = 0.1,
        cnn_module_kernel: int = 31,
        normalize_before: bool = True,
        vgg_frontend: bool = False,
        is_espnet_structure: bool = False,
        mmi_loss: bool = True,
        use_feat_batchnorm: bool = False,
    ) -> None:
        super(Conformer, self).__init__(
            num_features=num_features,
@@ -58,9 +54,6 @@ class Conformer(Transformer):
            num_decoder_layers=num_decoder_layers,
            dropout=dropout,
            normalize_before=normalize_before,
            vgg_frontend=vgg_frontend,
            mmi_loss=mmi_loss,
            use_feat_batchnorm=use_feat_batchnorm,
        )

        self.encoder_pos = RelPositionalEncoding(d_model, dropout)
@@ -6,7 +6,6 @@ from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
from subsampling import Conv2dSubsampling, VggSubsampling
from torch.nn.utils.rnn import pad_sequence

# Note: TorchScript requires Dict/List/etc. to be fully typed.
@@ -18,7 +17,6 @@ class Transformer(nn.Module):
        self,
        num_features: int,
        num_classes: int,
        subsampling_factor: int = 4,
        d_model: int = 256,
        nhead: int = 4,
        dim_feedforward: int = 2048,
@@ -26,9 +24,6 @@ class Transformer(nn.Module):
        num_decoder_layers: int = 6,
        dropout: float = 0.1,
        normalize_before: bool = True,
        vgg_frontend: bool = False,
        mmi_loss: bool = True,
        use_feat_batchnorm: bool = False,
    ) -> None:
        """
        Args:
@@ -54,16 +49,9 @@ class Transformer(nn.Module):
            Dropout in encoder/decoder.
          normalize_before:
            If True, use pre-layer norm; False to use post-layer norm.
          vgg_frontend:
            True to use vgg style frontend for subsampling.
          mmi_loss:
          use_feat_batchnorm:
            True to use batchnorm for the input layer.
        """
        super().__init__()
        self.use_feat_batchnorm = use_feat_batchnorm
        if use_feat_batchnorm:
            self.feat_batchnorm = nn.BatchNorm1d(num_features)

        self.num_features = num_features
        self.num_classes = num_classes
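As a rough sketch (not the actual encoder-layer code), the `normalize_before` flag documented above toggles between pre-layer-norm and post-layer-norm residual blocks along these lines:

import torch
import torch.nn as nn

def sublayer_step(x, sublayer, layer_norm, normalize_before=True):
    if normalize_before:
        # pre-LN: normalize first, then apply the sublayer and add the residual
        return x + sublayer(layer_norm(x))
    # post-LN: add the residual first, normalize afterwards
    return layer_norm(x + sublayer(x))

d_model = 256
ff = nn.Sequential(nn.Linear(d_model, 1024), nn.ReLU(), nn.Linear(1024, d_model))
y = sublayer_step(torch.randn(2, 10, d_model), ff, nn.LayerNorm(d_model))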
@@ -76,10 +64,10 @@ class Transformer(nn.Module):
        # That is, it does two things simultaneously:
        #   (1) subsampling: T -> T//subsampling_factor
        #   (2) embedding: num_classes -> d_model
        if vgg_frontend:
            self.encoder_embed = VggSubsampling(num_features, d_model)
        else:
            self.encoder_embed = Conv2dSubsampling(num_features, d_model)

        #self.encoder_embed = [TODO...]

        self.encoder_pos = PositionalEncoding(d_model, dropout)
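To make the comment above concrete, here is a small shape check. The sizes are illustrative, and the exact subsampled length depends on the convolution parameters in subsampling.py, so T//4 is only approximate.

import torch
from subsampling import Conv2dSubsampling  # from the same recipe directory

embed = Conv2dSubsampling(80, 256)  # num_features=80 -> d_model=256
x = torch.randn(8, 100, 80)         # [N, T, num_features]
y = embed(x)                        # approximately [N, T//4, d_model]
print(y.shape)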
@@ -108,14 +96,7 @@ class Transformer(nn.Module):
            )

        if num_decoder_layers > 0:
            if mmi_loss:
                self.decoder_num_class = (
                    self.num_classes + 1
                )  # +1 for the sos/eos symbol
            else:
                self.decoder_num_class = (
                    self.num_classes
                )  # bpe model already has sos/eos symbol
            self.decoder_num_class = self.num_classes

            self.decoder_embed = nn.Embedding(
                num_embeddings=self.decoder_num_class, embedding_dim=d_model
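A worked example of what this simplification means for the embedding size (500 is a hypothetical BPE vocabulary size):

num_classes = 500                    # hypothetical BPE vocabulary size
old_mmi_num_class = num_classes + 1  # old mmi_loss=True path: one extra id for sos/eos
new_decoder_num_class = num_classes  # refactored path: the BPE vocab already contains sos/eos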
@@ -150,12 +131,22 @@ class Transformer(nn.Module):
            self.decoder_criterion = None

    def forward(
        self, x: torch.Tensor, supervision: Optional[Supervisions] = None
        self,
        src_symbols: torch.Tensor,
        src_padding_mask: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """
        Args:
          x:
            The input tensor. Its shape is [N, T, C].
          src_symbols:
            The input symbols to be embedded (will actually have query positions
            masked), as a Tensor of shape (batch_size, seq_len) and dtype=torch.int64.
            I.e. shape (N, T)
          src_padding_mask:
            Either None, or a Tensor of shape (batch_size, seq_len) i.e. (N, T),
            and dtype=torch.bool which has True in positions to be masked in attention
            layers and convolutions because they represent padding at the ends of
            sequences.

          supervision:
            Supervision in lhotse format.
            See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
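A minimal sketch of how a caller might build `src_padding_mask` as described above, assuming per-sequence lengths are known; True marks the padded positions at the ends of sequences.

import torch

lengths = torch.tensor([7, 5, 3])  # hypothetical number of valid symbols per sequence
T = int(lengths.max())
src_padding_mask = torch.arange(T).unsqueeze(0) >= lengths.unsqueeze(1)  # (N, T), dtype=torch.bool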
@@ -171,10 +162,7 @@ class Transformer(nn.Module):
            memory_key_padding_mask for the decoder. Its shape is [N, T].
            It is None if `supervision` is None.
        """
        if self.use_feat_batchnorm:
            x = x.permute(0, 2, 1)  # [N, T, C] -> [N, C, T]
            x = self.feat_batchnorm(x)
            x = x.permute(0, 2, 1)  # [N, C, T] -> [N, T, C]

        encoder_memory, memory_key_padding_mask = self.run_encoder(
            x, supervision
        )
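The permutes around the feature batchnorm exist because nn.BatchNorm1d expects the channel dimension second, i.e. [N, C, T]. A standalone sketch of the same pattern (the feature dimension 80 is illustrative):

import torch
import torch.nn as nn

feat_batchnorm = nn.BatchNorm1d(80)  # matches the input feature dimension
x = torch.randn(4, 200, 80)          # [N, T, C]
x = x.permute(0, 2, 1)               # [N, T, C] -> [N, C, T]
x = feat_batchnorm(x)                # normalize each feature channel over N and T
x = x.permute(0, 2, 1)               # [N, C, T] -> [N, T, C]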