From 09bbed327572a05a5181792f1d4d72994e4c83ba Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Mon, 7 Feb 2022 19:10:21 +0800
Subject: [PATCH] Use CTC loss as auxiliary loss.

See https://github.com/k2-fsa/icefall/pull/186
---
NOTE(review): line breaks and indentation below were reconstructed from a
whitespace-collapsed copy of this patch. The layout agrees with every hunk
header's line counts and with the diffstat, but exact leading whitespace is
inferred -- confirm against the upstream commit before applying. The
"transdcuder_loss" typo in the --ctc-weight help text is also corrected.

 .../ASR/transducer_stateless/model.py | 11 ++-
 .../ASR/transducer_stateless/train.py | 79 ++++++++++++++++++-
 2 files changed, 85 insertions(+), 5 deletions(-)

diff --git a/egs/librispeech/ASR/transducer_stateless/model.py b/egs/librispeech/ASR/transducer_stateless/model.py
index 8281e1fb5..4e4f9d13d 100644
--- a/egs/librispeech/ASR/transducer_stateless/model.py
+++ b/egs/librispeech/ASR/transducer_stateless/model.py
@@ -79,7 +79,10 @@ class Transducer(nn.Module):
           modified_transducer_prob:
             The probability to use modified transducer loss.
         Returns:
-          Return the transducer loss.
+          Return a tuple containing:
+            - the transducer loss, a tensor containing only one entry
+            - encoder_out, a tensor of shape (N, num_frames, encoder_out_dim)
+            - encoder_out_lens, a tensor of shape (N,)
         """
         assert x.ndim == 3, x.shape
         assert x_lens.ndim == 1, x_lens.shape
@@ -140,4 +143,8 @@ class Transducer(nn.Module):
             from_log_softmax=False,
         )
 
-        return loss
+        return (
+            loss,
+            encoder_out,
+            x_lens,
+        )
diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py
index c0b1b3a42..0ceb523e7 100755
--- a/egs/librispeech/ASR/transducer_stateless/train.py
+++ b/egs/librispeech/ASR/transducer_stateless/train.py
@@ -157,6 +157,14 @@ def get_parser():
         help="If enabled, apply random frame shift along the time axis",
     )
 
+    parser.add_argument(
+        "--ctc-weight",
+        type=float,
+        default=0.25,
+        help="""If not zero, the total loss is:
+        (1 - ctc_weight) * transducer_loss + ctc_weight * ctc_loss
+        """,
+    )
     return parser
 
 
@@ -225,6 +233,13 @@ def get_params() -> AttributeDict:
             "vgg_frontend": False,
             # parameters for Noam
             "warm_step": 80000,  # For the 100h subset, use 8k
+            #
+            # parameters for ctc_loss, used only when ctc_weight > 0
+            "modified_ctc_topo": False,
+            "beam_size": 10,
+            "reduction": "sum",
+            "use_double_scores": True,
+            #
             "env_info": get_env_info(),
         }
     )
@@ -278,6 +293,17 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
     return model
 
 
+def get_ctc_model(params: AttributeDict):
+    if params.ctc_weight > 0:
+        return nn.Sequential(
+            nn.Dropout(p=0.1),
+            nn.Linear(params.encoder_out_dim, params.vocab_size),
+            nn.LogSoftmax(dim=-1),
+        )
+    else:
+        return None
+
+
 def load_checkpoint_if_available(
     params: AttributeDict,
     model: nn.Module,
@@ -398,16 +424,55 @@ def compute_loss(
     feature_lens = supervisions["num_frames"].to(device)
 
     texts = batch["supervisions"]["text"]
-    y = sp.encode(texts, out_type=int)
-    y = k2.RaggedTensor(y).to(device)
+    token_ids = sp.encode(texts, out_type=int)
+    y = k2.RaggedTensor(token_ids).to(device)
 
     with torch.set_grad_enabled(is_training):
-        loss = model(
+        transducer_loss, encoder_out, encoder_out_lens = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
             modified_transducer_prob=params.modified_transducer_prob,
         )
+        loss = transducer_loss
+
+        if params.ctc_weight > 0:
+            ctc_model = (
+                model.module.ctc if hasattr(model, "module") else model.ctc
+            )
+            ctc_graph = k2.ctc_graph(
+                token_ids, modified=params.modified_ctc_topo, device=device
+            )
+            # Note: We assume `encoder_out_lens` is sorted in descending order.
+            # If not, it will throw in k2.ctc_loss().
+            supervision_segments = torch.stack(
+                [
+                    torch.arange(encoder_out.size(0), dtype=torch.int32),
+                    torch.zeros(encoder_out.size(0), dtype=torch.int32),
+                    encoder_out_lens.cpu(),
+                ],
+                dim=1,
+            ).to(torch.int32)
+            nnet_out = ctc_model(encoder_out)
+
+            dense_fsa_vec = k2.DenseFsaVec(
+                nnet_out,
+                supervision_segments,
+                allow_truncate=0,
+            )
+
+            # Note: transducer_loss should use the same reduction as ctc_loss
+            ctc_loss = k2.ctc_loss(
+                decoding_graph=ctc_graph,
+                dense_fsa_vec=dense_fsa_vec,
+                output_beam=params.beam_size,
+                reduction=params.reduction,
+                use_double_scores=params.use_double_scores,
+            )
+            assert ctc_loss.requires_grad == is_training
+            loss = (
+                1 - params.ctc_weight
+            ) * transducer_loss + params.ctc_weight * ctc_loss
 
     assert loss.requires_grad == is_training
 
@@ -416,6 +481,9 @@ def compute_loss(
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
+    info["transducer_loss"] = transducer_loss.detach().cpu().item()
+    if params.ctc_weight > 0:
+        info["ctc_loss"] = ctc_loss.detach().cpu().item()
 
     return loss, info
 
@@ -598,6 +666,11 @@ def run(rank, world_size, args):
 
     logging.info("About to create model")
     model = get_transducer_model(params)
+    model.ctc = get_ctc_model(params)
+    if model.ctc is not None:
+        logging.info(f"Enable CTC loss with weight: {params.ctc_weight}")
+    else:
+        logging.info("Disable CTC loss")
 
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")