diff --git a/egs/librispeech/ASR/transducer_stateless/beam_search.py b/egs/librispeech/ASR/transducer_stateless/beam_search.py
index 3d4818509..70052a474 100644
--- a/egs/librispeech/ASR/transducer_stateless/beam_search.py
+++ b/egs/librispeech/ASR/transducer_stateless/beam_search.py
@@ -297,7 +297,7 @@ def beam_search(
             current_encoder_out, decoder_out.unsqueeze(1)
         )
 
-        # TODO(fangjun): Cache the blank posterior
+        # TODO(fangjun): Scale the blank posterior
 
         log_prob = logits.log_softmax(dim=-1)
         # log_prob is (1, 1, 1, vocab_size)
diff --git a/egs/librispeech/ASR/transducer_stateless/decoder.py b/egs/librispeech/ASR/transducer_stateless/decoder.py
index af3292edf..aec745a9c 100644
--- a/egs/librispeech/ASR/transducer_stateless/decoder.py
+++ b/egs/librispeech/ASR/transducer_stateless/decoder.py
@@ -84,24 +84,30 @@ class Decoder(nn.Module):
         Returns:
           Return a tensor of shape (N, U, embedding_dim).
         """
-        embeding_out = self.embedding(y)
+        embedding_out = self.embedding(y)
         if self.context_size > 1:
-            embeding_out = embeding_out.permute(0, 2, 1)
+            embedding_out = embedding_out.permute(0, 2, 1)
             if need_pad is True:
+                # If the input is [sos, a, b, c, d] and the output is
+                # [a, b, c, d, eos], padding on the left with kernel size 2
+                # uses left context.
+                # If the input is [a, b, c, d, eos] and the output is
+                # [sos, a, b, c, d], padding on the right with kernel size 2
+                # uses right context.
                 if self.backward:
                     assert self.context_size == 2
-                    embeding_out = F.pad(
-                        embeding_out, pad=(0, self.context_size - 1)
+                    embedding_out = F.pad(
+                        embedding_out, pad=(0, self.context_size - 1)
                     )
                 else:
-                    embeding_out = F.pad(
-                        embeding_out, pad=(self.context_size - 1, 0)
+                    embedding_out = F.pad(
+                        embedding_out, pad=(self.context_size - 1, 0)
                     )
             else:
                 # During inference time, there is no need to do extra padding
                 # as we only need one output
-                assert embeding_out.size(-1) == self.context_size
+                assert embedding_out.size(-1) == self.context_size
                 assert self.backward is False
-            embeding_out = self.conv(embeding_out)
-            embeding_out = embeding_out.permute(0, 2, 1)
-        return embeding_out
+            embedding_out = self.conv(embedding_out)
+            embedding_out = embedding_out.permute(0, 2, 1)
+        return embedding_out
diff --git a/egs/librispeech/ASR/transducer_stateless/model.py b/egs/librispeech/ASR/transducer_stateless/model.py
index 865fe903f..823bd8fca 100644
--- a/egs/librispeech/ASR/transducer_stateless/model.py
+++ b/egs/librispeech/ASR/transducer_stateless/model.py
@@ -20,7 +20,7 @@
 import torch
 import torch.nn as nn
 from encoder_interface import EncoderInterface
-from icefall.utils import add_sos
+from icefall.utils import add_eos, add_sos
 
 
 class Transducer(nn.Module):
@@ -50,10 +50,30 @@ class Transducer(nn.Module):
            It is the prediction network in the paper. Its input shape
            is (N, U) and its output shape is (N, U, C). It should contain
            one attribute: `blank_id`.
+          backward_decoder:
+            Almost the same as decoder, except that it uses right context
+            while decoder uses left context.
           joiner:
            It has two inputs with shapes: (N, T, C) and (N, U, C). Its
            output shape is (N, T, U, C). Note that its output contains
            unnormalized probs, i.e., not processed by log-softmax.
+          backward_joiner:
+            The same as joiner, but it is intended for backward_decoder.
+          prune_range:
+            The prune range for rnnt loss, i.e., how many symbols (context)
+            we consider for each frame when computing the loss.
+          am_scale:
+            The scale used to smooth the loss with the am (output of the
+            encoder network) part.
+          lm_scale:
+            The scale used to smooth the loss with the lm (output of the
+            predictor network) part.
+
+        Note:
+          Regarding am_scale & lm_scale, they make the loss function take
+          the form:
+            lm_scale * lm_probs + am_scale * am_probs +
+            (1 - lm_scale - am_scale) * combined_probs
         """
         super().__init__()
         assert isinstance(encoder, EncoderInterface), type(encoder)
@@ -110,7 +130,6 @@ class Transducer(nn.Module):
         # Note: y does not start with SOS
         y_padded = y.pad(mode="constant", padding_value=0)
 
-        y_padded = y_padded.to(torch.int64)
         boundary = torch.zeros(
             (x.size(0), 4), dtype=torch.int64, device=x.device
         )
@@ -121,7 +140,7 @@
         simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
             decoder_out,
             encoder_out,
-            y_padded,
+            y_padded.to(torch.int64),
             blank_id,
             lm_only_scale=self.lm_scale,
             am_only_scale=self.am_scale,
@@ -138,13 +157,15 @@
         )
         logits = self.joiner(am_pruned, lm_pruned)
         pruned_loss = k2.rnnt_loss_pruned(
-            logits, y_padded, ranges, blank_id, boundary
+            logits, y_padded.to(torch.int64), ranges, blank_id, boundary
         )
 
+        eos_y = add_eos(y, eos_id=blank_id)
+        eos_y_padded = eos_y.pad(mode="constant", padding_value=blank_id)
         # backward loss
         assert self.backward_decoder is not None
         assert self.backward_joiner is not None
-        backward_decoder_out = self.backward_decoder(sos_y_padded)
+        backward_decoder_out = self.backward_decoder(eos_y_padded)
         backward_am_pruned, backward_lm_pruned = k2.do_rnnt_pruning(
             encoder_out, backward_decoder_out, ranges
         )
@@ -152,7 +173,11 @@
             backward_am_pruned, backward_lm_pruned
         )
         backward_pruned_loss = k2.rnnt_loss_pruned(
-            backward_logits, y_padded, ranges, blank_id, boundary
+            backward_logits,
+            sos_y_padded.to(torch.int64),
+            ranges,
+            blank_id,
+            boundary,
         )
 
         return (
diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py
index d35d3c66a..8849a098f 100755
--- a/egs/librispeech/ASR/transducer_stateless/train.py
+++ b/egs/librispeech/ASR/transducer_stateless/train.py
@@ -227,7 +227,7 @@ def get_params() -> AttributeDict:
             "log_interval": 50,
             "reset_interval": 200,
             "valid_interval": 3000,  # For the 100h subset, use 800
-            "log_diagnostics": False,
+            "log_diagnostics": True,
             # parameters for conformer
             "feature_dim": 80,
             "encoder_out_dim": 512,
@@ -246,7 +246,7 @@
     return params
 
 
-def get_encoder_model(params: AttributeDict):
+def get_encoder_model(params: AttributeDict) -> nn.Module:
     # TODO: We can add an option to switch between Conformer and Transformer
     encoder = Conformer(
         num_features=params.feature_dim,
@@ -261,7 +261,9 @@
     return encoder
 
 
-def get_decoder_model(params: AttributeDict, backward: bool = False):
+def get_decoder_model(
+    params: AttributeDict, backward: bool = False
+) -> nn.Module:
     decoder = Decoder(
         vocab_size=params.vocab_size,
         embedding_dim=params.encoder_out_dim,
@@ -272,7 +274,7 @@
     return decoder
 
 
-def get_joiner_model(params: AttributeDict):
+def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
         input_dim=params.encoder_out_dim,
         output_dim=params.vocab_size,
@@ -280,7 +282,7 @@
     return joiner
 
 
-def get_transducer_model(params: AttributeDict):
+def get_transducer_model(params: AttributeDict) -> nn.Module:
     encoder = get_encoder_model(params)
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)