From 8be10d3d6c39dbb51f932c8cebea7cb67055ed92 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sat, 2 Apr 2022 20:03:21 +0800
Subject: [PATCH] First draft of model rework

---
 .../pruned_transducer_stateless2/conformer.py | 11 +------
 .../pruned_transducer_stateless2/decoder.py   | 17 +++++-----
 .../pruned_transducer_stateless2/joiner.py    | 16 ++++++---
 .../ASR/pruned_transducer_stateless2/model.py | 33 ++++++++-----------
 .../ASR/pruned_transducer_stateless2/train.py | 22 ++++++++-----
 5 files changed, 49 insertions(+), 50 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
index c7ce3bec2..0deb960ad 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
@@ -32,8 +32,6 @@ class Conformer(EncoderInterface):
     """
     Args:
         num_features (int): Number of input features
-        output_dim (int): Model output dimension. If not equal to the
-          encoder dimension, we will project to the output.
         subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
         d_model (int): attention dimension, also the output dimension
         nhead (int): number of head
@@ -47,7 +45,6 @@ class Conformer(EncoderInterface):
     def __init__(
         self,
         num_features: int,
-        output_dim: int,
         subsampling_factor: int = 4,
         d_model: int = 256,
         nhead: int = 4,
@@ -83,11 +80,6 @@ class Conformer(EncoderInterface):
         )
         self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
 
-        if output_dim == d_model:
-            self.encoder_output_layer = nn.Identity()
-        else:
-            self.encoder_output_layer = ScaledLinear(d_model, output_dim,
-                                                     initial_speed=0.5)
 
     def forward(
         self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
@@ -123,7 +115,6 @@ class Conformer(EncoderInterface):
         x = self.encoder(x, pos_emb, src_key_padding_mask=mask, warmup=warmup)  # (T, N, C)
 
-        x = self.encoder_output_layer(x)
 
         x = x.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
 
         return x, lengths
@@ -1116,7 +1107,7 @@ class Noam(object):
 
 if __name__ == '__main__':
     feature_dim = 50
-    c = Conformer(num_features=feature_dim, output_dim=256, d_model=128, nhead=4)
+    c = Conformer(num_features=feature_dim, d_model=128, nhead=4)
     batch_size = 5
     seq_len = 20
     # Just make sure the forward pass runs.
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py
index a442feeea..25a36223d 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py
@@ -46,8 +46,8 @@ class Decoder(nn.Module):
         Args:
           vocab_size:
             Number of tokens of the modeling unit including blank.
-          embedding_dim:
-            Dimension of the input embedding.
+          decoder_dim:
+            Dimension of the input embedding, and of the decoder output.
          blank_id:
            The ID of the blank symbol.
          context_size:
@@ -63,23 +63,22 @@ class Decoder(nn.Module):
         initial_speed = 0.5
         self.embedding = ScaledEmbedding(
             num_embeddings=vocab_size,
-            embedding_dim=embedding_dim,
+            embedding_dim=decoder_dim,
             padding_idx=blank_id,
             initial_speed=initial_speed
         )
         self.blank_id = blank_id
-        self.output_linear = ScaledLinear(embedding_dim, embedding_dim)
 
         assert context_size >= 1, context_size
         self.context_size = context_size
         self.vocab_size = vocab_size
         if context_size > 1:
             self.conv = ScaledConv1d(
-                in_channels=embedding_dim,
-                out_channels=embedding_dim,
+                in_channels=decoder_dim,
+                out_channels=decoder_dim,
                 kernel_size=context_size,
                 padding=0,
-                groups=embedding_dim,
+                groups=decoder_dim,
                 bias=False,
             )
@@ -92,7 +91,7 @@ class Decoder(nn.Module):
             True to left pad the input. Should be True during training.
             False to not pad the input. Should be False during inference.
         Returns:
-          Return a tensor of shape (N, U, embedding_dim).
+          Return a tensor of shape (N, U, decoder_dim).
         """
         y = y.to(torch.int64)
         embedding_out = self.embedding(y)
@@ -108,5 +107,5 @@ class Decoder(nn.Module):
                 assert embedding_out.size(-1) == self.context_size
             embedding_out = self.conv(embedding_out)
             embedding_out = embedding_out.permute(0, 2, 1)
-        embedding_out = self.output_linear(F.relu(embedding_out))
+        embedding_out = F.relu(embedding_out)
         return embedding_out
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py
index b9c465398..64752b9a0 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py
@@ -20,11 +20,19 @@ import torch.nn.functional as F
 from scaling import ScaledLinear
 
 
 class Joiner(nn.Module):
-    def __init__(self, input_dim: int, output_dim: int):
+    def __init__(self,
+                 encoder_dim: int,
+                 decoder_dim: int,
+                 joiner_dim: int,
+                 vocab_size: int):
         super().__init__()
 
-        self.output_linear = ScaledLinear(input_dim, output_dim,
-                                          initial_speed=0.5)
+        # We don't bother giving the 'initial_speed' arg to the decoder
+        # submodules, because it does not affect the initial convergence of the
+        # system (only the simple joiner is involved in that).
+        self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim)
+        self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim)
+        self.output_linear = ScaledLinear(joiner_dim, vocab_size)
 
     def forward(
         self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
@@ -41,7 +49,7 @@ class Joiner(nn.Module):
         assert encoder_out.ndim == decoder_out.ndim == 4
         assert encoder_out.shape == decoder_out.shape
 
-        logit = encoder_out + decoder_out
+        logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
 
         logit = self.output_linear(torch.tanh(logit))
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py
index 0355c4531..5d4c32ac4 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py
@@ -34,23 +34,25 @@ class Transducer(nn.Module):
         encoder: EncoderInterface,
         decoder: nn.Module,
         joiner: nn.Module,
-        embedding_dim: int,
+        encoder_dim: int,
+        decoder_dim: int,
+        joiner_dim: int,
         vocab_size: int
     ):
         """
         Args:
          encoder:
            It is the transcription network in the paper. Its accepts
-            two inputs: `x` of (N, T, C) and `x_lens` of shape (N,).
-            It returns two tensors: `logits` of shape (N, T, C) and
+            two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
+            It returns two tensors: `logits` of shape (N, T, encoder_dim) and
            `logit_lens` of shape (N,).
          decoder:
            It is the prediction network in the paper. Its input shape
-            is (N, U) and its output shape is (N, U, C). It should contain
+            is (N, U) and its output shape is (N, U, decoder_dim). It should contain
            one attribute: `blank_id`.
          joiner:
-            It has two inputs with shapes: (N, T, C) and (N, U, C). Its
-            output shape is (N, T, U, C). Note that its output contains
+            It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim). Its
+            output shape is (N, T, U, vocab_size). Note that its output contains
            unnormalized probs, i.e., not processed by log-softmax.
        """
        super().__init__()
@@ -61,17 +63,10 @@ class Transducer(nn.Module):
        self.decoder = decoder
        self.joiner = joiner
 
-        self.simple_am_proj = ScaledLinear(embedding_dim, vocab_size,
+        self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size,
                                           initial_speed=0.5)
-        self.simple_lm_proj = ScaledLinear(embedding_dim, vocab_size,
+        self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size,
                                           initial_speed=0.5)
-        with torch.no_grad():
-            # Initialize the two projections to be the same; this will be
-            # convenient for the real joiner, which adds the endcoder
-            # (acoustic-model/am) and decoder (language-model/lm) embeddings
-            self.simple_lm_proj.weight[:] = self.simple_am_proj.weight
-            self.simple_lm_proj.bias[:] = self.simple_am_proj.bias
-
    def forward(
        self,
@@ -133,7 +128,7 @@ class Transducer(nn.Module):
        # sos_y_padded: [B, S + 1], start with SOS.
        sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
 
-        # decoder_out: [B, S + 1, C]
+        # decoder_out: [B, S + 1, decoder_dim]
        decoder_out = self.decoder(sos_y_padded)
 
        # Note: y does not start with SOS
@@ -167,13 +162,13 @@ class Transducer(nn.Module):
            s_range=prune_range,
        )
 
-        # am_pruned : [B, T, prune_range, C]
-        # lm_pruned : [B, T, prune_range, C]
+        # am_pruned : [B, T, prune_range, encoder_dim]
+        # lm_pruned : [B, T, prune_range, decoder_dim]
        am_pruned, lm_pruned = k2.do_rnnt_pruning(
            am=encoder_out, lm=decoder_out, ranges=ranges
        )
 
-        # logits : [B, T, prune_range, C]
+        # logits : [B, T, prune_range, vocab_size]
        logits = self.joiner(am_pruned, lm_pruned)
 
        pruned_loss = k2.rnnt_loss_pruned(
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
index c716d457a..a027a5adc 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
@@ -268,7 +268,7 @@ def get_params() -> AttributeDict:
 
        - subsampling_factor: The subsampling factor for the model.
 
-        - attention_dim: Hidden dim for multi-head attention model.
+        - encoder_dim: Hidden dim for multi-head attention model.
 
        - num_decoder_layers: Number of decoder layer of transformer decoder.
 
@@ -287,12 +287,14 @@ def get_params() -> AttributeDict:
            # parameters for conformer
            "feature_dim": 80,
            "subsampling_factor": 4,
-            "attention_dim": 512,
+            "encoder_dim": 512,
            "nhead": 8,
            "dim_feedforward": 2048,
            "num_encoder_layers": 12,
            # parameters for decoder
-            "embedding_dim": 512,
+            "decoder_dim": 512,
+            # parameters for joiner
+            "joiner_dim": 512,
            # parameters for Noam
            "warm_step": 60000,  # For the 100h subset, use 8k
            "model_warm_step": 4000,  # arg given to model, not for lrate
@@ -309,7 +311,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
        num_features=params.feature_dim,
        output_dim=params.embedding_dim,
        subsampling_factor=params.subsampling_factor,
-        d_model=params.attention_dim,
+        d_model=params.encoder_dim,
        nhead=params.nhead,
        dim_feedforward=params.dim_feedforward,
        num_encoder_layers=params.num_encoder_layers,
@@ -329,8 +331,10 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:
 
 def get_joiner_model(params: AttributeDict) -> nn.Module:
    joiner = Joiner(
-        input_dim=params.embedding_dim,
-        output_dim=params.vocab_size,
+        encoder_dim=params.encoder_dim,
+        decoder_dim=params.decoder_dim,
+        joiner_dim=params.joiner_dim,
+        vocab_size=params.vocab_size,
    )
    return joiner
 
@@ -344,7 +348,9 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
        encoder=encoder,
        decoder=decoder,
        joiner=joiner,
-        embedding_dim=params.embedding_dim,
+        encoder_dim=params.encoder_dim,
+        decoder_dim=params.decoder_dim,
+        joiner_dim=params.joiner_dim,
        vocab_size=params.vocab_size,
    )
    return model
@@ -748,7 +754,7 @@ def run(rank, world_size, args):
 
    optimizer = Noam(
        model.parameters(),
-        model_size=params.attention_dim,
+        model_size=params.encoder_dim,
        factor=params.lr_factor,
        warm_step=params.warm_step,
    )
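
Note (not part of the patch): the sketch below is a minimal, self-contained illustration of the reworked Joiner wiring introduced above, i.e. separate encoder/decoder projections to a common joiner_dim followed by tanh and a final projection to vocab_size. It substitutes plain nn.Linear for icefall's ScaledLinear, and the vocab_size of 500 is an assumed, illustrative value (the 512s match the new defaults added to train.py), so it demonstrates the interface rather than the exact implementation.

# joiner_sketch.py -- hypothetical, illustration only; not part of the patch.
import torch
import torch.nn as nn


class JoinerSketch(nn.Module):
    """Mirrors the reworked Joiner wiring with plain nn.Linear layers."""

    def __init__(self, encoder_dim: int, decoder_dim: int,
                 joiner_dim: int, vocab_size: int):
        super().__init__()
        # Separate projections bring encoder and decoder outputs to a
        # common joiner dimension before they are combined.
        self.encoder_proj = nn.Linear(encoder_dim, joiner_dim)
        self.decoder_proj = nn.Linear(decoder_dim, joiner_dim)
        self.output_linear = nn.Linear(joiner_dim, vocab_size)

    def forward(self, encoder_out: torch.Tensor,
                decoder_out: torch.Tensor) -> torch.Tensor:
        # encoder_out: (N, T, s_range, encoder_dim)
        # decoder_out: (N, T, s_range, decoder_dim)
        logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
        return self.output_linear(torch.tanh(logit))


if __name__ == "__main__":
    # The 512s follow the new train.py defaults; vocab_size=500 is assumed.
    joiner = JoinerSketch(encoder_dim=512, decoder_dim=512,
                          joiner_dim=512, vocab_size=500)
    enc = torch.randn(2, 10, 5, 512)   # (N, T, s_range, encoder_dim)
    dec = torch.randn(2, 10, 5, 512)   # (N, T, s_range, decoder_dim)
    print(joiner(enc, dec).shape)      # torch.Size([2, 10, 5, 500])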