diff --git a/egs/librispeech/ASR/transducer_stateless/decoder.py b/egs/librispeech/ASR/transducer_stateless/decoder.py
index dca084477..af3292edf 100644
--- a/egs/librispeech/ASR/transducer_stateless/decoder.py
+++ b/egs/librispeech/ASR/transducer_stateless/decoder.py
@@ -38,6 +38,7 @@ class Decoder(nn.Module):
         embedding_dim: int,
         blank_id: int,
         context_size: int,
+        backward: bool = False,
     ):
         """
         Args:
@@ -61,6 +62,7 @@ class Decoder(nn.Module):

         assert context_size >= 1, context_size
         self.context_size = context_size
+        self.backward = backward
         if context_size > 1:
             self.conv = nn.Conv1d(
                 in_channels=embedding_dim,
@@ -86,13 +88,20 @@ class Decoder(nn.Module):
         if self.context_size > 1:
             embeding_out = embeding_out.permute(0, 2, 1)
             if need_pad is True:
-                embeding_out = F.pad(
-                    embeding_out, pad=(self.context_size - 1, 0)
-                )
+                if self.backward:
+                    assert self.context_size == 2
+                    embeding_out = F.pad(
+                        embeding_out, pad=(0, self.context_size - 1)
+                    )
+                else:
+                    embeding_out = F.pad(
+                        embeding_out, pad=(self.context_size - 1, 0)
+                    )
             else:
                 # During inference time, there is no need to do extra padding
                 # as we only need one output
                 assert embeding_out.size(-1) == self.context_size
+                assert self.backward is False
             embeding_out = self.conv(embeding_out)
             embeding_out = embeding_out.permute(0, 2, 1)
         return embeding_out
diff --git a/egs/librispeech/ASR/transducer_stateless/model.py b/egs/librispeech/ASR/transducer_stateless/model.py
index 837d76ddc..865fe903f 100644
--- a/egs/librispeech/ASR/transducer_stateless/model.py
+++ b/egs/librispeech/ASR/transducer_stateless/model.py
@@ -32,8 +32,12 @@ class Transducer(nn.Module):
         self,
         encoder: EncoderInterface,
         decoder: nn.Module,
+        backward_decoder: nn.Module,
         joiner: nn.Module,
+        backward_joiner: nn.Module,
         prune_range: int = 3,
+        lm_scale: float = 0.0,
+        am_scale: float = 0.0,
     ):
         """
         Args:
@@ -57,8 +61,12 @@ class Transducer(nn.Module):

         self.encoder = encoder
         self.decoder = decoder
+        self.backward_decoder = backward_decoder
         self.joiner = joiner
+        self.backward_joiner = backward_joiner
         self.prune_range = prune_range
+        self.lm_scale = lm_scale
+        self.am_scale = am_scale

     def forward(
         self,
@@ -109,22 +117,46 @@ class Transducer(nn.Module):
         boundary[:, 2] = y_lens
         boundary[:, 3] = x_lens

-        simple_loss, (px_grad, py_grad) = k2.rnnt_loss_simple(
-            decoder_out, encoder_out, y_padded, blank_id, boundary, True
+        # calculate prune ranges
+        simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
+            decoder_out,
+            encoder_out,
+            y_padded,
+            blank_id,
+            lm_only_scale=self.lm_scale,
+            am_only_scale=self.am_scale,
+            boundary=boundary,
+            return_grad=True,
         )
-
         ranges = k2.get_rnnt_prune_ranges(
             px_grad, py_grad, boundary, self.prune_range
         )
-        am_pruning, lm_pruning = k2.do_rnnt_pruning(
+        # forward loss
+        am_pruned, lm_pruned = k2.do_rnnt_pruning(
             encoder_out, decoder_out, ranges
         )
-
-        logits = self.joiner(am_pruning, lm_pruning)
-
-        pruning_loss = k2.rnnt_loss_pruned(
+        logits = self.joiner(am_pruned, lm_pruned)
+        pruned_loss = k2.rnnt_loss_pruned(
             logits, y_padded, ranges, blank_id, boundary
         )
-        return (-torch.sum(simple_loss), -torch.sum(pruning_loss))
+        # backward loss
+        assert self.backward_decoder is not None
+        assert self.backward_joiner is not None
+        backward_decoder_out = self.backward_decoder(sos_y_padded)
+        backward_am_pruned, backward_lm_pruned = k2.do_rnnt_pruning(
+            encoder_out, backward_decoder_out, ranges
+        )
+        backward_logits = self.backward_joiner(
+            backward_am_pruned, backward_lm_pruned
+        )
+        backward_pruned_loss = k2.rnnt_loss_pruned(
+            backward_logits, y_padded, ranges, blank_id, boundary
+        )
+
+        return (
+            -torch.sum(simple_loss),
+            -torch.sum(pruned_loss),
+            -torch.sum(backward_pruned_loss),
+        )
diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py
index 3b92a5ee4..d35d3c66a 100755
--- a/egs/librispeech/ASR/transducer_stateless/train.py
+++ b/egs/librispeech/ASR/transducer_stateless/train.py
@@ -154,6 +154,22 @@ def get_parser():
         "we are using to compute the loss",
     )

+    parser.add_argument(
+        "--lm-scale",
+        type=float,
+        default=0.0,
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
+    )
+
+    parser.add_argument(
+        "--am-scale",
+        type=float,
+        default=0.0,
+        help="The scale to smooth the loss with am "
+        "(output of encoder network) part.",
+    )
+
     return parser


@@ -245,12 +261,13 @@
     return encoder


-def get_decoder_model(params: AttributeDict):
+def get_decoder_model(params: AttributeDict, backward: bool = False):
     decoder = Decoder(
         vocab_size=params.vocab_size,
         embedding_dim=params.encoder_out_dim,
         blank_id=params.blank_id,
         context_size=params.context_size,
+        backward=backward,
     )
     return decoder

@@ -267,11 +284,18 @@ def get_transducer_model(params: AttributeDict):
     encoder = get_encoder_model(params)
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)
+    backward_decoder = get_decoder_model(params, backward=True)
+    backward_joiner = get_joiner_model(params)

     model = Transducer(
         encoder=encoder,
         decoder=decoder,
+        backward_decoder=backward_decoder,
         joiner=joiner,
+        backward_joiner=backward_joiner,
+        prune_range=params.prune_range,
+        lm_scale=params.lm_scale,
+        am_scale=params.am_scale,
     )
     return model

@@ -400,8 +424,10 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)

     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss = model(x=feature, x_lens=feature_lens, y=y)
-        loss = simple_loss + pruned_loss
+        simple_loss, pruned_loss, backward_pruned_loss = model(
+            x=feature, x_lens=feature_lens, y=y
+        )
+        loss = simple_loss + pruned_loss + backward_pruned_loss

     assert loss.requires_grad == is_training

@@ -412,6 +438,7 @@ def compute_loss(
     info["loss"] = loss.detach().cpu().item()
     info["simple_loss"] = simple_loss.detach().cpu().item()
     info["pruned_loss"] = pruned_loss.detach().cpu().item()
+    info["backward_pruned_loss"] = backward_pruned_loss.detach().cpu().item()
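
Note for reviewers (illustration only, not part of the patch): the backward decoder differs from the forward one only in which side of the token sequence gets zero-padded before the context convolution in decoder.py. Below is a minimal sketch of that difference, assuming context_size == 2 and the (N, embedding_dim, U) layout decoder.py uses after the permute; the toy tensor sizes and the names forward_in / backward_in are made up for the example.

import torch
import torch.nn.functional as F

# Toy embedding output: (batch, embedding_dim, num_tokens), as in decoder.py
# after embeding_out.permute(0, 2, 1).
embeding_out = torch.randn(1, 4, 3)
context_size = 2

# Forward decoder: pad on the left, so the kernel output at token u only
# depends on tokens u-1 and u (a left-to-right prediction context).
forward_in = F.pad(embeding_out, pad=(context_size - 1, 0))

# Backward decoder: pad on the right, so the kernel output at token u only
# depends on tokens u and u+1 (a right-to-left prediction context).
backward_in = F.pad(embeding_out, pad=(0, context_size - 1))

print(forward_in.shape, backward_in.shape)  # both torch.Size([1, 4, 4])

With the same sos_y_padded input, the right-padded variant therefore predicts each position from the current and following tokens; that is the predictor the backward_decoder / backward_joiner branch in model.py scores, producing the extra backward_pruned_loss term that train.py adds to the total loss.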