diff --git a/egs/librispeech/ASR/transducer_stateless_aux_kl/model.py b/egs/librispeech/ASR/transducer_stateless_aux_kl/model.py index 8141f9a83..109b59847 100644 --- a/egs/librispeech/ASR/transducer_stateless_aux_kl/model.py +++ b/egs/librispeech/ASR/transducer_stateless_aux_kl/model.py @@ -15,11 +15,12 @@ # limitations under the License. import random -from typing import Optional +from typing import Optional, Tuple import k2 import torch import torch.nn as nn +import torch.nn.functional as F from encoder_interface import EncoderInterface from icefall.utils import add_sos @@ -37,6 +38,7 @@ class Transducer(nn.Module): joiner: nn.Module, decoder_giga: Optional[nn.Module] = None, joiner_giga: Optional[nn.Module] = None, + aux_module: Optional[nn.Module] = None, ): """ Args: @@ -57,6 +59,8 @@ class Transducer(nn.Module): The decoder for the GigaSpeech dataset. joiner_giga: The joiner for the GigaSpeech dataset. + aux_module: + Optional. The auxiliary branch for computing auxiliary losses. """ super().__init__() assert isinstance(encoder, EncoderInterface), type(encoder) @@ -73,6 +77,8 @@ class Transducer(nn.Module): self.decoder_giga = decoder_giga self.joiner_giga = joiner_giga + self.aux_module = aux_module + def forward( self, x: torch.Tensor, @@ -80,7 +86,7 @@ class Transducer(nn.Module): y: k2.RaggedTensor, libri: bool = True, modified_transducer_prob: float = 0.0, - ) -> torch.Tensor: + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Args: x: @@ -97,7 +103,10 @@ class Transducer(nn.Module): modified_transducer_prob: The probability to use modified transducer loss. Returns: - Return the transducer loss. 
+ Return a tuple of 3 scalar tensors containing: + - transducer loss + - auxiliary transducer loss + - KL loss """ assert x.ndim == 3, x.shape assert x_lens.ndim == 1, x_lens.shape @@ -105,7 +114,7 @@ class Transducer(nn.Module): assert x.size(0) == x_lens.size(0) == y.dim0 - encoder_out, x_lens = self.encoder(x, x_lens) + encoder_out, x_lens, aux_input = self.encoder(x, x_lens) assert torch.all(x_lens > 0) # Now for the decoder, i.e., the prediction network @@ -154,15 +163,52 @@ class Transducer(nn.Module): else: one_sym_per_frame = False - loss = optimized_transducer.transducer_loss( - logits=logits, + log_probs = F.log_softmax(logits, dim=-1) + transducer_loss = optimized_transducer.transducer_loss( + logits=log_probs, targets=y_padded, logit_lengths=x_lens, target_lengths=y_lens, blank=blank_id, reduction="sum", one_sym_per_frame=one_sym_per_frame, - from_log_softmax=False, + from_log_softmax=True, ) - return loss + assert self.aux_module is not None, "aux_module must be provided" + aux_output = self.aux_module(aux_input) + + # Now process the auxiliary branch + aux_logits = joiner( + encoder_out=aux_output, + decoder_out=decoder_out, + encoder_out_len=x_lens, + decoder_out_len=y_lens + 1, + ) + aux_log_probs = F.log_softmax(aux_logits, dim=-1) + + aux_transducer_loss = optimized_transducer.transducer_loss( + logits=aux_log_probs, + targets=y_padded, + logit_lengths=x_lens, + target_lengths=y_lens, + blank=blank_id, + reduction="sum", + one_sym_per_frame=one_sym_per_frame, + from_log_softmax=True, + ) + kl_loss_1 = F.kl_div( + input=log_probs, + target=aux_log_probs, + reduction="sum", + log_target=True, + ) + kl_loss_2 = F.kl_div( + input=aux_log_probs, + target=log_probs, + reduction="sum", + log_target=True, + ) + kl_loss = (kl_loss_1 + kl_loss_2) * 0.5 + + return transducer_loss, aux_transducer_loss, kl_loss diff --git a/egs/librispeech/ASR/transducer_stateless_aux_kl/train.py b/egs/librispeech/ASR/transducer_stateless_aux_kl/train.py index ae91e76fd..d56beed9e 100755 --- 
a/egs/librispeech/ASR/transducer_stateless_aux_kl/train.py +++ b/egs/librispeech/ASR/transducer_stateless_aux_kl/train.py @@ -168,6 +168,13 @@ def get_parser(): help="The probability to select a batch from the GigaSpeech dataset", ) + parser.add_argument( + "--lambda-aux", + type=float, + default=0.3, + help="The scale applied to the auxiliary losses", + ) + return parser @@ -280,6 +287,14 @@ def get_joiner_model(params: AttributeDict) -> nn.Module: return joiner +def get_aux_model(params: AttributeDict) -> nn.Module: + return nn.Sequential( + nn.Linear(params.attention_dim, params.encoder_out_dim), + nn.ReLU(inplace=True), + nn.Linear(params.encoder_out_dim, params.encoder_out_dim), + ) + + def get_transducer_model(params: AttributeDict) -> nn.Module: encoder = get_encoder_model(params) @@ -289,12 +304,15 @@ def get_transducer_model(params: AttributeDict) -> nn.Module: decoder_giga = get_decoder_model(params) joiner_giga = get_joiner_model(params) + aux_module = get_aux_model(params) + model = Transducer( encoder=encoder, decoder=decoder, joiner=joiner, decoder_giga=decoder_giga, joiner_giga=joiner_giga, + aux_module=aux_module, ) return model @@ -436,7 +454,7 @@ def compute_loss( y = k2.RaggedTensor(y).to(device) with torch.set_grad_enabled(is_training): - loss = model( + transducer_loss, aux_transducer_loss, kl_loss = model( x=feature, x_lens=feature_lens, y=y, @@ -444,15 +462,25 @@ def compute_loss( modified_transducer_prob=params.modified_transducer_prob, ) - assert loss.requires_grad == is_training + aux_loss = aux_transducer_loss + kl_loss + + assert transducer_loss.requires_grad == is_training + assert aux_transducer_loss.requires_grad == is_training + assert kl_loss.requires_grad == is_training info = MetricsTracker() info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. 
- info["loss"] = loss.detach().cpu().item() + info["tot_loss"] = ( + (transducer_loss + params.lambda_aux * aux_loss).detach().cpu().item() + ) - return loss, info + info["transducer_loss"] = transducer_loss.detach().cpu().item() + info["aux_transducer_loss"] = aux_transducer_loss.detach().cpu().item() + info["kl_loss"] = kl_loss.detach().cpu().item() + + return transducer_loss, aux_loss, info def compute_validation_loss( @@ -468,7 +496,7 @@ def compute_validation_loss( tot_loss = MetricsTracker() for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( + loss, aux_loss, loss_info = compute_loss( params=params, model=model, sp=sp, @@ -481,7 +509,7 @@ def compute_validation_loss( if world_size > 1: tot_loss.reduce(loss.device) - loss_value = tot_loss["loss"] / tot_loss["frames"] + loss_value = tot_loss["tot_loss"] / tot_loss["frames"] if loss_value < params.best_valid_loss: params.best_valid_epoch = params.cur_epoch params.best_valid_loss = loss_value @@ -557,7 +585,7 @@ def train_one_epoch( libri = is_libri(batch["supervisions"]["cut"][0]) - loss, loss_info = compute_loss( + transducer_loss, aux_loss, loss_info = compute_loss( params=params, model=model, sp=sp, @@ -581,7 +609,9 @@ def train_one_epoch( # in the batch and there is no normalization to it so far. optimizer.zero_grad() - loss.backward() + + (transducer_loss + aux_loss * params.lambda_aux).backward() + clip_grad_norm_(model.parameters(), 5.0, 2.0) optimizer.step() @@ -849,14 +879,16 @@ def scan_pessimistic_batches_for_oom( batch = train_dl.dataset[cuts] try: optimizer.zero_grad() - loss, _ = compute_loss( + transducer_loss, aux_loss, _ = compute_loss( params=params, model=model, sp=sp, batch=batch, is_training=True, ) - loss.backward() + + (transducer_loss + aux_loss * params.lambda_aux).backward() + clip_grad_norm_(model.parameters(), 5.0, 2.0) optimizer.step() except RuntimeError as e: