diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py
index deebb2a75..61b53067a 100644
--- a/egs/librispeech/ASR/zipformer/model.py
+++ b/egs/librispeech/ASR/zipformer/model.py
@@ -16,11 +16,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 
 import k2
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from encoder_interface import EncoderInterface
 from scaling import ScaledLinear
 
@@ -111,9 +112,8 @@ class AsrModel(nn.Module):
         if use_ctc:
             # Modules for CTC head
             self.ctc_output = nn.Sequential(
-                nn.Dropout(p=0.1),
+                nn.Dropout(p=0.1),  # TODO: test removing this
                 nn.Linear(encoder_dim, vocab_size),
-                nn.LogSoftmax(dim=-1),
             )
 
         self.use_attention_decoder = use_attention_decoder
@@ -159,71 +159,86 @@ class AsrModel(nn.Module):
         encoder_out: torch.Tensor,
         encoder_out_lens: torch.Tensor,
         targets: torch.Tensor,
         target_lengths: torch.Tensor,
-    ) -> torch.Tensor:
-        """Compute CTC loss.
-        Args:
-          encoder_out:
-            Encoder output, of shape (N, T, C).
-          encoder_out_lens:
-            Encoder output lengths, of shape (N,).
-          targets:
-            Target Tensor of shape (sum(target_lengths)). The targets are assumed
-            to be un-padded and concatenated within 1 dimension.
-        """
-        # Compute CTC log-prob
-        ctc_output = self.ctc_output(encoder_out)  # (N, T, C)
-
-        ctc_loss = torch.nn.functional.ctc_loss(
-            log_probs=ctc_output.permute(1, 0, 2),  # (T, N, C)
-            targets=targets.cpu(),
-            input_lengths=encoder_out_lens.cpu(),
-            target_lengths=target_lengths.cpu(),
-            reduction="sum",
-        )
-        return ctc_loss
-
-    def forward_cr_ctc(
-        self,
-        encoder_out: torch.Tensor,
-        encoder_out_lens: torch.Tensor,
-        targets: torch.Tensor,
-        target_lengths: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Compute CTC loss with consistency regularization loss.
+        use_consistency_reg: bool = False,
+        use_smooth_reg: bool = False,
+        smooth_kernel: List[float] = [0.25, 0.5, 0.25],
+        eps: float = 1e-6,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Compute CTC loss, optionally with consistency and smooth regularization.
         Args:
           encoder_out:
-            Encoder output, of shape (2 * N, T, C).
+            Encoder output, of shape (N or 2 * N, T, C).
           encoder_out_lens:
-            Encoder output lengths, of shape (2 * N,).
+            Encoder output lengths, of shape (N or 2 * N,).
           targets:
-            Target Tensor of shape (2 * sum(target_lengths)). The targets are assumed
+            Target Tensor of shape (sum(target_lengths)). The targets are assumed
             to be un-padded and concatenated within 1 dimension.
+          use_consistency_reg:
+            Whether to use consistency regularization.
+          use_smooth_reg:
+            Whether to use smooth regularization.
+          smooth_kernel:
+            Kernel used to smooth the CTC posteriors along time; 3 taps summing to 1.
+          eps:
+            A small constant added to probabilities before taking log.
""" + ctc_output = self.ctc_output(encoder_out) # (N or 2 * N, T, C) + length_mask = make_pad_mask(encoder_out_lens).unsqueeze(-1) + + if not use_smooth_reg: + ctc_log_probs = F.log_softmax(ctc_output, dim=-1) + else: + ctc_probs = ctc_output.softmax(dim=-1) # Used in sr_loss + ctc_log_probs = (ctc_probs + eps).log() + # Compute CTC loss - ctc_output = self.ctc_output(encoder_out) # (2 * N, T, C) ctc_loss = torch.nn.functional.ctc_loss( - log_probs=ctc_output.permute(1, 0, 2), # (T, 2 * N, C) + log_probs=ctc_log_probs.permute(1, 0, 2), # (T, N or 2 * N, C) targets=targets.cpu(), input_lengths=encoder_out_lens.cpu(), target_lengths=target_lengths.cpu(), reduction="sum", ) - # Compute consistency regularization loss - exchanged_targets = ctc_output.detach().chunk(2, dim=0) - exchanged_targets = torch.cat( - [exchanged_targets[1], exchanged_targets[0]], dim=0 - ) # exchange: [x1, x2] -> [x2, x1] - cr_loss = nn.functional.kl_div( - input=ctc_output, - target=exchanged_targets, - reduction="none", - log_target=True, - ) # (2 * N, T, C) - length_mask = make_pad_mask(encoder_out_lens).unsqueeze(-1) - cr_loss = cr_loss.masked_fill(length_mask, 0.0).sum() + if use_consistency_reg: + assert ctc_log_probs.shape[0] % 2 == 0 + # Compute cr_loss + exchanged_targets = ctc_log_probs.detach().chunk(2, dim=0) + exchanged_targets = torch.cat( + [exchanged_targets[1], exchanged_targets[0]], dim=0 + ) # exchange: [x1, x2] -> [x2, x1] + cr_loss = nn.functional.kl_div( + input=ctc_log_probs, + target=exchanged_targets, + reduction="none", + log_target=True, + ) # (2 * N, T, C) + cr_loss = cr_loss.masked_fill(length_mask, 0.0).sum() + else: + cr_loss = torch.empty(0) - return ctc_loss, cr_loss + if use_smooth_reg: + # Hard code the kernel here, could try other values + assert len(smooth_kernel) == 3 and sum(smooth_kernel) == 1.0, smooth_kernel + smooth_kernel = torch.tensor(smooth_kernel, dtype=ctc_probs.dtype, + device=ctc_probs.device, requires_grad=False) + smooth_kernel = smooth_kernel.unsqueeze(0).unsqueeze(1).expand(ctc_probs.shape[-1], 1, 3) + # Now kernel: (C, 1, 3) + smoothed_ctc_probs = F.conv1d( + ctc_probs.detach().permute(0, 2, 1), # (N or 2 * N, C, T) + weight=smooth_kernel, stride=1, padding=0, groups=ctc_probs.shape[-1] + ).permute(0, 2, 1) # (N or 2 * N, T - 2, C) + sr_loss = nn.functional.kl_div( + input=ctc_log_probs[:, 1:-1], + target=(smoothed_ctc_probs + eps).log(), + reduction="none", + log_target=True, + ) # (N, T - 1 , C) + sr_loss = sr_loss.masked_fill(length_mask[:, 1:-1], 0.0).sum() + else: + sr_loss = torch.empty(0) + + return ctc_loss, cr_loss, sr_loss def forward_transducer( self, @@ -341,6 +352,7 @@ class AsrModel(nn.Module): am_scale: float = 0.0, lm_scale: float = 0.0, use_cr_ctc: bool = False, + use_sr_ctc: bool = False, use_spec_aug: bool = False, spec_augment: Optional[SpecAugment] = None, supervision_segments: Optional[torch.Tensor] = None, @@ -367,6 +379,8 @@ class AsrModel(nn.Module): part use_cr_ctc: Whether use consistency-regularized CTC. + use_sr_ctc: + Whether use smooth-regularized CTC. use_spec_aug: Whether apply spec-augment manually, used only if use_cr_ctc is True. 
@@ -341,6 +356,7 @@ class AsrModel(nn.Module):
         am_scale: float = 0.0,
         lm_scale: float = 0.0,
         use_cr_ctc: bool = False,
+        use_sr_ctc: bool = False,
         use_spec_aug: bool = False,
         spec_augment: Optional[SpecAugment] = None,
         supervision_segments: Optional[torch.Tensor] = None,
@@ -367,6 +383,8 @@ class AsrModel(nn.Module):
             part
           use_cr_ctc:
             Whether use consistency-regularized CTC.
+          use_sr_ctc:
+            Whether to use smooth-regularized CTC.
           use_spec_aug:
             Whether apply spec-augment manually, used only if use_cr_ctc is True.
           spec_augment:
@@ -445,26 +463,23 @@ class AsrModel(nn.Module):
         if self.use_ctc:
             # Compute CTC loss
             targets = y.values
-            if not use_cr_ctc:
-                ctc_loss = self.forward_ctc(
-                    encoder_out=encoder_out,
-                    encoder_out_lens=encoder_out_lens,
-                    targets=targets,
-                    target_lengths=y_lens,
-                )
-                cr_loss = torch.empty(0)
-            else:
-                ctc_loss, cr_loss = self.forward_cr_ctc(
-                    encoder_out=encoder_out,
-                    encoder_out_lens=encoder_out_lens,
-                    targets=targets,
-                    target_lengths=y_lens,
-                )
+            ctc_loss, cr_loss, sr_loss = self.forward_ctc(
+                encoder_out=encoder_out,
+                encoder_out_lens=encoder_out_lens,
+                targets=targets,
+                target_lengths=y_lens,
+                use_consistency_reg=use_cr_ctc,
+                use_smooth_reg=use_sr_ctc,
+            )
+            if use_cr_ctc:
+                # The batch is doubled when use_cr_ctc is True, so halve the losses
                 ctc_loss = ctc_loss * 0.5
                 cr_loss = cr_loss * 0.5
+                sr_loss = sr_loss * 0.5
         else:
             ctc_loss = torch.empty(0)
             cr_loss = torch.empty(0)
+            sr_loss = torch.empty(0)
 
         if self.use_attention_decoder:
             attention_decoder_loss = self.attention_decoder.calc_att_loss(
@@ -478,4 +493,4 @@ class AsrModel(nn.Module):
         else:
             attention_decoder_loss = torch.empty(0)
 
-        return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss
+        return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, sr_loss
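For context on the `use_consistency_reg` branch (moved here from the old `forward_cr_ctc`): CR-CTC feeds two differently-augmented copies of each utterance through the model as one doubled batch and regresses each half onto the detached posteriors of the other, which is also why `forward` halves the losses when `use_cr_ctc` is set. A toy sketch of the exchange step, with illustrative names:

```python
import torch
import torch.nn.functional as F

# A doubled batch: rows [0, N) and [N, 2N) are two augmented views of the
# same N utterances. In the model these log-probs come from the CTC head.
N, T, C = 2, 10, 5
ctc_log_probs = torch.randn(2 * N, T, C).log_softmax(dim=-1).requires_grad_()

# Swap the halves so each view is supervised by the other one.
first, second = ctc_log_probs.detach().chunk(2, dim=0)
exchanged = torch.cat([second, first], dim=0)  # [x1, x2] -> [x2, x1]

# Per-element KL divergence; the patch masks out padded frames with
# make_pad_mask before the final sum.
cr_loss = F.kl_div(
    input=ctc_log_probs,
    target=exchanged,  # detached, so the target side receives no gradient
    reduction="none",
    log_target=True,
).sum()
```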
diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py
index 3a8995c81..ba1332300 100755
--- a/egs/librispeech/ASR/zipformer/train.py
+++ b/egs/librispeech/ASR/zipformer/train.py
@@ -312,6 +312,13 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="If True, use consistency-regularized CTC.",
     )
 
+    parser.add_argument(
+        "--use-sr-ctc",
+        type=str2bool,
+        default=False,
+        help="If True, use smooth-regularized CTC.",
+    )
+
 
 def get_parser():
     parser = argparse.ArgumentParser(
@@ -464,6 +471,13 @@ def get_parser():
         help="Scale for consistency-regularization loss.",
     )
 
+    parser.add_argument(
+        "--sr-loss-scale",
+        type=float,
+        default=0.2,
+        help="Scale for smooth-regularization loss.",
+    )
+
     parser.add_argument(
         "--time-mask-ratio",
         type=float,
@@ -916,6 +930,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)
 
     use_cr_ctc = params.use_cr_ctc
+    use_sr_ctc = params.use_sr_ctc
     use_spec_aug = use_cr_ctc and is_training
     if use_spec_aug:
         supervision_intervals = batch["supervisions"]
@@ -931,7 +946,7 @@ def compute_loss(
         supervision_segments = None
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss = model(
+        simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, sr_loss = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -939,6 +954,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
             use_cr_ctc=use_cr_ctc,
+            use_sr_ctc=use_sr_ctc,
             use_spec_aug=use_spec_aug,
             spec_augment=spec_augment,
             supervision_segments=supervision_segments,
@@ -967,6 +983,8 @@ def compute_loss(
             loss += params.ctc_loss_scale * ctc_loss
             if use_cr_ctc:
                 loss += params.cr_loss_scale * cr_loss
+            if use_sr_ctc:
+                loss += params.sr_loss_scale * sr_loss
 
     if params.use_attention_decoder:
         loss += params.attention_decoder_loss_scale * attention_decoder_loss
@@ -985,8 +1003,10 @@ def compute_loss(
         info["pruned_loss"] = pruned_loss.detach().cpu().item()
     if params.use_ctc:
         info["ctc_loss"] = ctc_loss.detach().cpu().item()
-    if params.use_cr_ctc:
+    if use_cr_ctc:
         info["cr_loss"] = cr_loss.detach().cpu().item()
+    if use_sr_ctc:
+        info["sr_loss"] = sr_loss.detach().cpu().item()
     if params.use_attention_decoder:
         info["attn_decoder_loss"] = attention_decoder_loss.detach().cpu().item()
 
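With the flags above wired in, smooth-regularized CTC would be enabled roughly like this (a sketch: `--use-ctc`, `--world-size`, and `--num-epochs` already exist in this train.py; the values shown are illustrative, and `--sr-loss-scale 0.2` just restates the default):

```bash
./zipformer/train.py \
  --world-size 4 \
  --num-epochs 50 \
  --use-ctc 1 \
  --use-sr-ctc 1 \
  --sr-loss-scale 0.2
```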