diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
index d8b184752..03a47927f 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
@@ -32,9 +32,10 @@ class Conformer(EncoderInterface):
     """
     Args:
         num_features (int): Number of input features
-        output_dim (int): Number of output dimension
+        output_dim (int): Model output dimension. If not equal to the encoder
+          dimension (d_model), the encoder output is projected to output_dim.
         subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
-        d_model (int): attention dimension
+        d_model (int): attention dimension, also the output dimension
         nhead (int): number of head
         dim_feedforward (int): feedforward dimention
         num_encoder_layers (int): number of encoder layers
@@ -42,7 +43,6 @@ class Conformer(EncoderInterface):
         cnn_module_kernel (int): Kernel size of convolution module
         vgg_frontend (bool): whether to use vgg frontend.
     """
-
     def __init__(
         self,
         num_features: int,
@@ -59,7 +59,6 @@ class Conformer(EncoderInterface):
         super(Conformer, self).__init__()
 
         self.num_features = num_features
-        self.output_dim = output_dim
         self.subsampling_factor = subsampling_factor
         if subsampling_factor != 4:
             raise NotImplementedError("Support only 'subsampling_factor=4'.")
@@ -83,7 +82,11 @@ class Conformer(EncoderInterface):
         self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers,
                                         aux_layers=list(range(0, num_encoder_layers-1, aux_layer_period)))
 
-        self.encoder_output_layer = ScaledLinear(d_model, output_dim)
+        if output_dim == d_model:
+            self.encoder_output_layer = Identity()
+        else:
+            self.encoder_output_layer = ScaledLinear(d_model, output_dim,
+                                                     initial_speed=0.5)
 
     def forward(
         self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
@@ -101,9 +104,9 @@ class Conformer(EncoderInterface):
           to turn modules on sequentially.
         Returns:
           Return a tuple containing 2 tensors:
-            - logits, its shape is (batch_size, output_seq_len, output_dim)
-            - logit_lens, a tensor of shape (batch_size,) containing the number
-              of frames in `logits` before padding.
+            - embeddings: its shape is (batch_size, output_seq_len, d_model)
+            - lengths, a tensor of shape (batch_size,) containing the number
+              of frames in `embeddings` before padding.
""" x = self.encoder_embed(x) x, pos_emb = self.encoder_pos(x) @@ -117,10 +120,10 @@ class Conformer(EncoderInterface): x = self.encoder(x, pos_emb, src_key_padding_mask=mask, warmup=warmup) # (T, N, C) - logits = self.encoder_output_layer(x) - logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + x = self.encoder_output_layer(x) + x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - return logits, lengths + return x, lengths class ConformerEncoderLayer(nn.Module): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py index 3470b647f..a442feeea 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py @@ -68,6 +68,7 @@ class Decoder(nn.Module): initial_speed=initial_speed ) self.blank_id = blank_id + self.output_linear = ScaledLinear(embedding_dim, embedding_dim) assert context_size >= 1, context_size self.context_size = context_size @@ -81,8 +82,6 @@ class Decoder(nn.Module): groups=embedding_dim, bias=False, ) - self.output_linear = ScaledLinear(embedding_dim, vocab_size) - def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: """ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py index 61bfe8186..973a89bfe 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py @@ -20,11 +20,10 @@ import torch.nn.functional as F from scaling import ScaledLinear class Joiner(nn.Module): - def __init__(self, input_dim: int, inner_dim: int, output_dim: int): + def __init__(self, input_dim: int, output_dim: int): super().__init__() - self.inner_linear = ScaledLinear(input_dim, inner_dim) - self.output_linear = ScaledLinear(inner_dim, output_dim) + self.output_linear = ScaledLinear(input_dim, output_dim) def forward( self, encoder_out: torch.Tensor, decoder_out: torch.Tensor @@ -43,8 +42,6 @@ class Joiner(nn.Module): logit = encoder_out + decoder_out - logit = self.inner_linear(torch.tanh(logit)) - - output = self.output_linear(F.relu(logit)) + logit = self.output_linear(torch.tanh(logit)) return output diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py index faaebc477..2f102bdf8 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py @@ -19,6 +19,7 @@ import k2 import torch import torch.nn as nn from encoder_interface import EncoderInterface +from scaling import ScaledLinear from icefall.utils import add_sos @@ -33,6 +34,8 @@ class Transducer(nn.Module): encoder: EncoderInterface, decoder: nn.Module, joiner: nn.Module, + embedding_dim: int, + vocab_size: int ): """ Args: @@ -58,6 +61,10 @@ class Transducer(nn.Module): self.decoder = decoder self.joiner = joiner + # could perhaps separate this into 2 linear projections, one + # for lm and one for am. 
+        self.simple_joiner = nn.Linear(embedding_dim, vocab_size)
+
     def forward(
         self,
         x: torch.Tensor,
@@ -133,8 +140,8 @@ class Transducer(nn.Module):
         boundary[:, 3] = x_lens
 
         simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
-            lm=decoder_out,
-            am=encoder_out,
+            lm=self.simple_joiner(decoder_out),
+            am=self.simple_joiner(encoder_out),
             symbols=y_padded,
             termination_symbol=blank_id,
             lm_only_scale=lm_scale,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
index 8d5142937..649234f0f 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
@@ -306,7 +306,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
     # TODO: We can add an option to switch between Conformer and Transformer
     encoder = Conformer(
         num_features=params.feature_dim,
-        output_dim=params.vocab_size,
+        output_dim=params.embedding_dim,
         subsampling_factor=params.subsampling_factor,
         d_model=params.attention_dim,
         nhead=params.nhead,
@@ -328,8 +328,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:
 
 def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
-        input_dim=params.vocab_size,
-        inner_dim=params.embedding_dim,
+        input_dim=params.embedding_dim,
         output_dim=params.vocab_size,
     )
     return joiner
@@ -344,6 +343,8 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
         encoder=encoder,
         decoder=decoder,
         joiner=joiner,
+        embedding_dim=params.embedding_dim,
+        vocab_size=params.vocab_size,
     )
     return model
 