From 69679e023e147fe4d09d80a9e1b1b5c9b8b5f212 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Sat, 22 Jan 2022 16:04:17 +0800
Subject: [PATCH] Use CTC loss as auxiliary loss.

---
 .../ASR/transducer_stateless/model.py | 11 ++-
 .../ASR/transducer_stateless/train.py | 82 ++++++++++++++++++-
 2 files changed, 88 insertions(+), 5 deletions(-)

diff --git a/egs/librispeech/ASR/transducer_stateless/model.py b/egs/librispeech/ASR/transducer_stateless/model.py
index 7aac290d9..1199e604e 100644
--- a/egs/librispeech/ASR/transducer_stateless/model.py
+++ b/egs/librispeech/ASR/transducer_stateless/model.py
@@ -74,7 +74,10 @@ class Transducer(nn.Module):
         A ragged tensor with 2 axes [utt][label]. It contains labels of each
         utterance.
     Returns:
-      Return the transducer loss.
+      Return a tuple containing:
+        - the transducer loss, a tensor containing only one entry
+        - encoder_out, a tensor of shape (N, num_frames, encoder_out_dim)
+        - encoder_out_lens, a tensor of shape (N,)
     """
     assert x.ndim == 3, x.shape
     assert x_lens.ndim == 1, x_lens.shape
@@ -123,4 +126,8 @@ class Transducer(nn.Module):
             from_log_softmax=False,
         )
 
-        return loss
+        return (
+            loss,
+            encoder_out,
+            x_lens,
+        )
diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py
index 694ebf1d5..5b27a2482 100755
--- a/egs/librispeech/ASR/transducer_stateless/train.py
+++ b/egs/librispeech/ASR/transducer_stateless/train.py
@@ -138,6 +138,15 @@ def get_parser():
         "2 means tri-gram",
     )
 
+    parser.add_argument(
+        "--ctc-weight",
+        type=float,
+        default=0.25,
+        help="""If not zero, the total loss is:
+        (1 - ctc_weight) * transducer_loss + ctc_weight * ctc_loss
+        """,
+    )
+
     return parser
 
 
@@ -206,6 +215,13 @@ def get_params() -> AttributeDict:
             "vgg_frontend": False,
             # parameters for Noam
             "warm_step": 80000,  # For the 100h subset, use 8k
+            #
+            # parameters for ctc_loss, used only when ctc_weight > 0
+            "modified_ctc_topo": False,
+            "beam_size": 10,
+            "reduction": "sum",
+            "use_double_scores": True,
+            #
             "env_info": get_env_info(),
         }
     )
@@ -259,6 +275,17 @@ def get_transducer_model(params: AttributeDict):
     return model
 
 
+def get_ctc_model(params: AttributeDict):
+    if params.ctc_weight > 0:
+        return nn.Sequential(
+            nn.Dropout(p=0.1),
+            nn.Linear(params.encoder_out_dim, params.vocab_size),
+            nn.LogSoftmax(dim=-1),
+        )
+    else:
+        return None
+
+
 def load_checkpoint_if_available(
     params: AttributeDict,
     model: nn.Module,
@@ -379,11 +406,52 @@ def compute_loss(
     feature_lens = supervisions["num_frames"].to(device)
 
     texts = batch["supervisions"]["text"]
-    y = sp.encode(texts, out_type=int)
-    y = k2.RaggedTensor(y).to(device)
+    token_ids = sp.encode(texts, out_type=int)
+    y = k2.RaggedTensor(token_ids).to(device)
 
     with torch.set_grad_enabled(is_training):
-        loss = model(x=feature, x_lens=feature_lens, y=y)
+        transducer_loss, encoder_out, encoder_out_lens = model(
+            x=feature, x_lens=feature_lens, y=y
+        )
+        loss = transducer_loss
+
+        if params.ctc_weight > 0:
+            ctc_model = (
+                model.module.ctc if hasattr(model, "module") else model.ctc
+            )
+            ctc_graph = k2.ctc_graph(
+                token_ids, modified=params.modified_ctc_topo, device=device
+            )
+            # Note: We assume `encoder_out_lens` is sorted in descending order.
+            # If it is not, k2 raises an error when the CTC loss is computed.
+            supervision_segments = torch.stack(
+                [
+                    torch.arange(encoder_out.size(0), dtype=torch.int32),
+                    torch.zeros(encoder_out.size(0), dtype=torch.int32),
+                    encoder_out_lens.cpu(),
+                ],
+                dim=1,
+            ).to(torch.int32)
+            nnet_out = ctc_model(encoder_out)
+
+            dense_fsa_vec = k2.DenseFsaVec(
+                nnet_out,
+                supervision_segments,
+                allow_truncate=0,
+            )
+
+            # Note: transducer_loss should use the same reduction as ctc_loss
+            ctc_loss = k2.ctc_loss(
+                decoding_graph=ctc_graph,
+                dense_fsa_vec=dense_fsa_vec,
+                output_beam=params.beam_size,
+                reduction=params.reduction,
+                use_double_scores=params.use_double_scores,
+            )
+            assert ctc_loss.requires_grad == is_training
+
+            loss = (
+                1 - params.ctc_weight
+            ) * transducer_loss + params.ctc_weight * ctc_loss
 
     assert loss.requires_grad == is_training
 
@@ -392,6 +460,9 @@
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
+    info["transducer_loss"] = transducer_loss.detach().cpu().item()
+    if params.ctc_weight > 0:
+        info["ctc_loss"] = ctc_loss.detach().cpu().item()
 
     return loss, info
 
@@ -574,6 +645,11 @@ def run(rank, world_size, args):
 
     logging.info("About to create model")
     model = get_transducer_model(params)
+    model.ctc = get_ctc_model(params)
+    if model.ctc is not None:
+        logging.info(f"Enable CTC loss with weight: {params.ctc_weight}")
+    else:
+        logging.info("Disable CTC loss")
 
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
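
For reference, the objective this patch optimizes is
loss = (1 - ctc_weight) * transducer_loss + ctc_weight * ctc_loss, with both
terms summed over the batch. The sketch below reproduces that combination end
to end. It is illustrative only: torch.nn.functional.ctc_loss stands in for
k2.ctc_loss (so no CTC graph or DenseFsaVec is needed), and the encoder
output, lengths, targets, and transducer-loss value are made-up stand-ins
rather than the recipe's real tensors.

import torch
import torch.nn as nn
import torch.nn.functional as F

vocab_size = 500       # illustrative; stands in for params.vocab_size
encoder_out_dim = 512  # illustrative; stands in for params.encoder_out_dim
ctc_weight = 0.25      # the default of the new --ctc-weight option

# The same head that get_ctc_model() builds:
# dropout -> linear projection -> log-softmax over the vocabulary.
ctc_model = nn.Sequential(
    nn.Dropout(p=0.1),
    nn.Linear(encoder_out_dim, vocab_size),
    nn.LogSoftmax(dim=-1),
)

N, T, U = 2, 100, 10  # batch size, encoder frames, max label length
encoder_out = torch.randn(N, T, encoder_out_dim, requires_grad=True)
encoder_out_lens = torch.tensor([100, 80])  # descending, as the patch assumes
targets = torch.randint(1, vocab_size, (N, U))  # 0 is reserved for blank
target_lens = torch.tensor([10, 7])

nnet_out = ctc_model(encoder_out)  # (N, T, vocab_size) log-probs

# torch's CTC expects (T, N, C); k2 instead takes (N, T, C) together with
# the supervision_segments built in compute_loss().
ctc_loss = F.ctc_loss(
    nnet_out.permute(1, 0, 2),
    targets,
    input_lengths=encoder_out_lens,
    target_lengths=target_lens,
    blank=0,
    reduction="sum",  # must match the reduction of the transducer loss
)

transducer_loss = torch.tensor(350.0, requires_grad=True)  # stand-in value

loss = (1 - ctc_weight) * transducer_loss + ctc_weight * ctc_loss
loss.backward()
print(f"total: {loss.item():.1f}, ctc: {ctc_loss.item():.1f}")

Keeping the same reduction on both terms matters: if one loss were
mean-reduced and the other summed, the effective mixing weight would vary
with batch size and utterance length, which is why compute_loss() notes that
transducer_loss should use the same reduction as ctc_loss.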