Mirror of https://github.com/k2-fsa/icefall.git
commit 076a70b62d
parent cbe5ee1111

    Initial conformer refactoring, not nearly done
@@ -26,7 +26,6 @@ class Conformer(Transformer):
         dropout (float): dropout rate
         cnn_module_kernel (int): Kernel size of convolution module
         normalize_before (bool): whether to use layer_norm before the first block.
-        vgg_frontend (bool): whether to use vgg frontend.
     """

     def __init__(
@@ -42,10 +41,7 @@ class Conformer(Transformer):
         dropout: float = 0.1,
         cnn_module_kernel: int = 31,
         normalize_before: bool = True,
-        vgg_frontend: bool = False,
         is_espnet_structure: bool = False,
-        mmi_loss: bool = True,
-        use_feat_batchnorm: bool = False,
     ) -> None:
         super(Conformer, self).__init__(
             num_features=num_features,
@@ -58,9 +54,6 @@ class Conformer(Transformer):
             num_decoder_layers=num_decoder_layers,
             dropout=dropout,
             normalize_before=normalize_before,
-            vgg_frontend=vgg_frontend,
-            mmi_loss=mmi_loss,
-            use_feat_batchnorm=use_feat_batchnorm,
         )

         self.encoder_pos = RelPositionalEncoding(d_model, dropout)
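Taken together, the three hunks above strip vgg_frontend, mmi_loss and use_feat_batchnorm out of the Conformer constructor. For orientation, a hedged usage sketch follows; the full signature is not visible in this diff, so only arguments that appear in the hunks are passed, and all concrete values (80, 500, the import path) are illustrative assumptions rather than anything taken from the commit.

# Hypothetical usage sketch, assuming the refactored conformer module is
# importable as `conformer` and the arguments shown in the hunks keep the
# defaults visible above.
from conformer import Conformer  # assumption: recipe-local module name

model = Conformer(
    num_features=80,            # e.g. 80-dim fbank input (illustrative)
    num_classes=500,            # e.g. BPE vocabulary size (illustrative)
    d_model=256,
    nhead=4,
    dropout=0.1,
    cnn_module_kernel=31,
    normalize_before=True,
    is_espnet_structure=False,
    # vgg_frontend=..., mmi_loss=..., use_feat_batchnorm=...  no longer accepted
)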
@@ -6,7 +6,6 @@ from typing import Dict, List, Optional, Tuple

 import torch
 import torch.nn as nn
-from subsampling import Conv2dSubsampling, VggSubsampling
 from torch.nn.utils.rnn import pad_sequence

 # Note: TorchScript requires Dict/List/etc. to be fully typed.
@@ -18,7 +17,6 @@ class Transformer(nn.Module):
         self,
         num_features: int,
         num_classes: int,
-        subsampling_factor: int = 4,
         d_model: int = 256,
         nhead: int = 4,
         dim_feedforward: int = 2048,
@@ -26,9 +24,6 @@ class Transformer(nn.Module):
         num_decoder_layers: int = 6,
         dropout: float = 0.1,
         normalize_before: bool = True,
-        vgg_frontend: bool = False,
-        mmi_loss: bool = True,
-        use_feat_batchnorm: bool = False,
     ) -> None:
         """
         Args:
@@ -54,16 +49,9 @@ class Transformer(nn.Module):
             Dropout in encoder/decoder.
           normalize_before:
             If True, use pre-layer norm; False to use post-layer norm.
-          vgg_frontend:
-            True to use vgg style frontend for subsampling.
-          mmi_loss:
-          use_feat_batchnorm:
-            True to use batchnorm for the input layer.
         """
         super().__init__()
-        self.use_feat_batchnorm = use_feat_batchnorm
-        if use_feat_batchnorm:
-            self.feat_batchnorm = nn.BatchNorm1d(num_features)
-
         self.num_features = num_features
         self.num_classes = num_classes
@@ -76,10 +64,10 @@ class Transformer(nn.Module):
         # That is, it does two things simultaneously:
         #   (1) subsampling: T -> T//subsampling_factor
         #   (2) embedding: num_classes -> d_model
-        if vgg_frontend:
-            self.encoder_embed = VggSubsampling(num_features, d_model)
-        else:
-            self.encoder_embed = Conv2dSubsampling(num_features, d_model)
+        #self.encoder_embed = [TODO...]

         self.encoder_pos = PositionalEncoding(d_model, dropout)

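The comment block above explains that encoder_embed both subsamples in time and projects the features to d_model. Below is a small, self-contained toy module in the same spirit, written for this note rather than taken from icefall's subsampling.py; the layer sizes and the example shapes are assumptions.

# A minimal, self-contained sketch (NOT the icefall implementation) of what an
# encoder_embed frontend does: subsample along time and project to d_model.
import torch
import torch.nn as nn


class ToyConv2dSubsampling(nn.Module):
    """Subsamples T by roughly 4x with two stride-2 convs, then maps to d_model."""

    def __init__(self, num_features: int, d_model: int) -> None:
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, d_model, kernel_size=3, stride=2),
            nn.ReLU(),
            nn.Conv2d(d_model, d_model, kernel_size=3, stride=2),
            nn.ReLU(),
        )
        # The feature axis shrinks the same way as the time axis.
        freq = (((num_features - 1) // 2) - 1) // 2
        self.out = nn.Linear(d_model * freq, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (N, T, num_features) -> (N, 1, T, num_features)
        x = self.conv(x.unsqueeze(1))
        n, c, t, f = x.shape
        # (N, C, T', F') -> (N, T', C * F') -> (N, T', d_model)
        return self.out(x.transpose(1, 2).reshape(n, t, c * f))


x = torch.randn(2, 100, 80)            # (N, T, num_features), made-up sizes
y = ToyConv2dSubsampling(80, 256)(x)
print(y.shape)                         # torch.Size([2, 24, 256]); T: 100 -> roughly T // 4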
@@ -108,14 +96,7 @@ class Transformer(nn.Module):
         )

         if num_decoder_layers > 0:
-            if mmi_loss:
-                self.decoder_num_class = (
-                    self.num_classes + 1
-                )  # +1 for the sos/eos symbol
-            else:
-                self.decoder_num_class = (
-                    self.num_classes
-                )  # bpe model already has sos/eos symbol
+            self.decoder_num_class = self.num_classes

             self.decoder_embed = nn.Embedding(
                 num_embeddings=self.decoder_num_class, embedding_dim=d_model
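The branch deleted here sized the decoder vocabulary differently depending on mmi_loss: with MMI an extra class was appended to serve as sos/eos, otherwise the (BPE) vocabulary was assumed to contain sos/eos already. After the change only the second behaviour remains. A trivial, self-contained illustration; the vocabulary size is made up.

# Illustration only; 500 is a made-up vocabulary size, not from the commit.
num_classes = 500

# Old code with mmi_loss=True: one extra id appended to serve as sos/eos.
decoder_num_class_with_mmi = num_classes + 1   # 501

# Old code with mmi_loss=False, and the only behaviour after this commit:
# the token set (e.g. a BPE model) already reserves ids for sos/eos.
decoder_num_class = num_classes                # 500

print(decoder_num_class_with_mmi, decoder_num_class)  # 501 500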
@@ -150,12 +131,22 @@ class Transformer(nn.Module):
             self.decoder_criterion = None

     def forward(
-        self, x: torch.Tensor, supervision: Optional[Supervisions] = None
+        self,
+        src_symbols: torch.Tensor,
+        src_padding_mask: torch.Tensor = None
     ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
         """
         Args:
-          x:
-            The input tensor. Its shape is [N, T, C].
+          src_symbols:
+            The input symbols to be embedded (will actually have query positions
+            masked), as a Tensor of shape (batch_size, seq_len) and dtype=torch.int64.
+            I.e. shape (N, T)
+          src_padding_mask:
+            Either None, or a Tensor of shape (batch_size, seq_len) i.e. (N, T),
+            and dtype=torch.bool which has True in positions to be masked in attention
+            layers and convolutions because they represent padding at the ends of
+            sequences.
+
           supervision:
             Supervision in lhotse format.
             See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
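The new src_padding_mask argument is documented above as a bool tensor with True at padded positions. A minimal sketch of building such a mask from per-sequence lengths follows; the helper name and the example lengths are made up, only the True-means-padding convention comes from the docstring.

# Minimal sketch (not from the commit): build an (N, T) bool padding mask from
# per-sequence lengths, True at positions that are padding.
import torch


def make_src_padding_mask(lengths: torch.Tensor) -> torch.Tensor:
    """lengths: (N,) int64 number of valid frames/symbols per sequence."""
    max_len = int(lengths.max())
    # Positions at or beyond the sequence length are padding.
    return torch.arange(max_len, device=lengths.device).unsqueeze(0) >= lengths.unsqueeze(1)


lengths = torch.tensor([5, 3, 4])
mask = make_src_padding_mask(lengths)
print(mask)
# tensor([[False, False, False, False, False],
#         [False, False, False,  True,  True],
#         [False, False, False, False,  True]])
print(mask.shape, mask.dtype)   # torch.Size([3, 5]) torch.bool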
@@ -171,10 +162,7 @@ class Transformer(nn.Module):
              memory_key_padding_mask for the decoder. Its shape is [N, T].
              It is None if `supervision` is None.
         """
-        if self.use_feat_batchnorm:
-            x = x.permute(0, 2, 1)  # [N, T, C] -> [N, C, T]
-            x = self.feat_batchnorm(x)
-            x = x.permute(0, 2, 1)  # [N, C, T] -> [N, T, C]
         encoder_memory, memory_key_padding_mask = self.run_encoder(
             x, supervision
         )
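For reference, the feature batch-norm being deleted here worked by moving the channel axis into the position nn.BatchNorm1d expects. A self-contained sketch of the same pattern; the shapes follow the [N, T, C] convention used in the docstring, and the tensor sizes are made up.

# Self-contained sketch of the removed feature-batchnorm pattern:
# nn.BatchNorm1d normalizes over dim 1, so a [N, T, C] tensor has to be
# permuted to [N, C, T] before the call and permuted back afterwards.
import torch
import torch.nn as nn

N, T, C = 2, 50, 80                 # batch, frames, feature dim (made-up sizes)
feat_batchnorm = nn.BatchNorm1d(C)

x = torch.randn(N, T, C)            # [N, T, C]
x = x.permute(0, 2, 1)              # [N, T, C] -> [N, C, T]
x = feat_batchnorm(x)
x = x.permute(0, 2, 1)              # [N, C, T] -> [N, T, C]
print(x.shape)                      # torch.Size([2, 50, 80])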