diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4_ctc/model.py b/egs/librispeech/ASR/pruned_transducer_stateless4_ctc/model.py
index 272d06c37..74dc99769 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless4_ctc/model.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4_ctc/model.py
@@ -69,6 +69,12 @@ class Transducer(nn.Module):
         self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5)
         self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)
 
+        self.ctc_output = nn.Sequential(
+            nn.Dropout(p=0.1),
+            ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5),
+            nn.LogSoftmax(dim=-1),
+        )
+
     def forward(
         self,
         x: torch.Tensor,
@@ -80,7 +86,7 @@ class Transducer(nn.Module):
         warmup: float = 1.0,
         reduction: str = "sum",
         delay_penalty: float = 0.0,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Args:
           x:
@@ -113,8 +119,7 @@ class Transducer(nn.Module):
             See https://github.com/k2-fsa/k2/issues/955 and
             https://arxiv.org/pdf/2211.00490.pdf for more details.
         Returns:
-          Returns:
-            Return the transducer loss.
+          Return a tuple containing simple loss, pruned loss, and ctc-output.
 
         Note:
            Regarding am_scale & lm_scale, it will make the loss-function one of
@@ -132,6 +137,9 @@ class Transducer(nn.Module):
         encoder_out, x_lens = self.encoder(x, x_lens, warmup=warmup)
         assert torch.all(x_lens > 0)
 
+        # compute ctc log-probs
+        ctc_output = self.ctc_output(encoder_out)
+
         # Now for the decoder, i.e., the prediction network
         row_splits = y.shape.row_splits(1)
         y_lens = row_splits[1:] - row_splits[:-1]
@@ -204,4 +212,4 @@ class Transducer(nn.Module):
             reduction=reduction,
         )
 
-        return (simple_loss, pruned_loss)
+        return (simple_loss, pruned_loss, ctc_output)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4_ctc/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4_ctc/train.py
index 9bd7df401..de76561cc 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless4_ctc/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4_ctc/train.py
@@ -97,6 +97,7 @@ from icefall.utils import (
     AttributeDict,
     MetricsTracker,
     display_and_save_batch,
+    encode_supervisions,
     setup_logger,
     str2bool,
 )
@@ -273,6 +274,13 @@ def get_parser():
         "with this parameter before adding to the final loss.",
     )
 
+    parser.add_argument(
+        "--ctc-loss-scale",
+        type=float,
+        default=0.2,
+        help="Scale for CTC loss.",
+    )
+
     parser.add_argument(
         "--seed",
         type=int,
@@ -341,6 +349,16 @@ def get_parser():
         https://arxiv.org/pdf/2211.00490.pdf for more details.""",
     )
 
+    parser.add_argument(
+        "--ctc-delay-penalty",
+        type=float,
+        default=0.0,
+        help="""A constant value used to penalize symbol delay for CTC loss,
+        to encourage streaming models to emit symbols earlier.
+        See https://github.com/k2-fsa/k2/issues/955 and
+        https://arxiv.org/pdf/2211.00490.pdf for more details.""",
+    )
+
     add_model_arguments(parser)
 
     return parser
@@ -412,6 +430,9 @@ def get_params() -> AttributeDict:
             "decoder_dim": 512,
             # parameters for joiner
             "joiner_dim": 512,
+            # parameters for ctc loss
+            "beam_size": 10,
+            "use_double_scores": True,
             # parameters for Noam
             "model_warm_step": 3000,  # arg given to model, not for lrate
             "env_info": get_env_info(),
@@ -627,11 +648,11 @@ def compute_loss(
     feature_lens = supervisions["num_frames"].to(device)
 
     texts = batch["supervisions"]["text"]
-    y = sp.encode(texts, out_type=int)
-    y = k2.RaggedTensor(y).to(device)
+    token_ids = sp.encode(texts, out_type=int)
+    y = k2.RaggedTensor(token_ids).to(device)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss = model(
+        simple_loss, pruned_loss, ctc_output = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -640,7 +661,7 @@ def compute_loss(
             lm_scale=params.lm_scale,
             warmup=warmup,
             reduction="none",
-            delay_penalty=params.delay_penalty if warmup >= 2.0 else 0,
+            delay_penalty=params.delay_penalty if warmup >= 1.0 else 0,
         )
         simple_loss_is_finite = torch.isfinite(simple_loss)
         pruned_loss_is_finite = torch.isfinite(pruned_loss)
@@ -674,6 +695,38 @@ def compute_loss(
         )
         loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
 
+        # Compute ctc loss
+
+        # NOTE: We need `encode_supervisions` to sort sequences with
+        # different duration in decreasing order, required by
+        # `k2.intersect_dense` called in `k2.ctc_loss`
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            supervision_segments, token_ids = encode_supervisions(
+                supervisions,
+                subsampling_factor=params.subsampling_factor,
+                token_ids=token_ids,
+            )
+
+        # Works with a BPE model
+        decoding_graph = k2.ctc_graph(token_ids, modified=False, device=device)
+        dense_fsa_vec = k2.DenseFsaVec(
+            ctc_output,
+            supervision_segments,
+            allow_truncate=params.subsampling_factor - 1,
+        )
+
+        ctc_loss = k2.ctc_loss(
+            decoding_graph=decoding_graph,
+            dense_fsa_vec=dense_fsa_vec,
+            output_beam=params.beam_size,
+            delay_penalty=params.ctc_delay_penalty if warmup >= 1.0 else 0.0,
+            reduction="sum",
+            use_double_scores=params.use_double_scores,
+        )
+        assert ctc_loss.requires_grad == is_training
+        loss += params.ctc_loss_scale * ctc_loss
+
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
@@ -698,6 +751,7 @@ def compute_loss(
 
     info["loss"] = loss.detach().cpu().item()
     info["simple_loss"] = simple_loss.detach().cpu().item()
    info["pruned_loss"] = pruned_loss.detach().cpu().item()
+    info["ctc_loss"] = ctc_loss.detach().cpu().item()
 
     return loss, info
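For reference, here is a minimal standalone sketch (not part of the patch) of how the CTC branch added above turns encoder log-probs into a loss with k2. The shapes, token ids, and the two-utterance batch are made up; the variable names only mirror those in the patch, and in training they come from the encoder output and `encode_supervisions`.

# Hypothetical standalone example, assuming torch and k2 are installed.
import k2
import torch

N, T, C = 2, 20, 500  # batch size, frames after subsampling, vocab size (blank id = 0)

# Stand-in for `ctc_output = self.ctc_output(encoder_out)`: per-frame log-probs.
ctc_output = torch.randn(N, T, C).log_softmax(dim=-1)

# One row per utterance: [sequence_index, start_frame, num_frames].
# Rows must be sorted by num_frames in decreasing order, which is what
# `encode_supervisions` guarantees before k2.intersect_dense is used.
supervision_segments = torch.tensor([[0, 0, 20], [1, 0, 15]], dtype=torch.int32)

# BPE token ids per utterance, in the same (sorted) order.
token_ids = [[23, 57, 8], [104, 7]]

decoding_graph = k2.ctc_graph(token_ids, modified=False, device="cpu")
dense_fsa_vec = k2.DenseFsaVec(ctc_output, supervision_segments)

ctc_loss = k2.ctc_loss(
    decoding_graph=decoding_graph,
    dense_fsa_vec=dense_fsa_vec,
    output_beam=10,  # params.beam_size in the patch
    reduction="sum",
    use_double_scores=True,
)
print(ctc_loss)  # scalar tensor; the patch scales it by --ctc-loss-scale and adds it to the loss

In the patch, `--ctc-delay-penalty` is additionally forwarded as the `delay_penalty` argument of the same `k2.ctc_loss` call once warmup reaches 1.0.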