diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py
index bd1ed26d8..61b53067a 100644
--- a/egs/librispeech/ASR/zipformer/model.py
+++ b/egs/librispeech/ASR/zipformer/model.py
@@ -16,15 +16,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 
 import k2
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from encoder_interface import EncoderInterface
 from scaling import ScaledLinear
 
-from icefall.utils import add_sos, make_pad_mask
+from icefall.utils import add_sos, make_pad_mask, time_warp
+from lhotse.dataset import SpecAugment
 
 
 class AsrModel(nn.Module):
@@ -110,9 +112,8 @@ class AsrModel(nn.Module):
         if use_ctc:
             # Modules for CTC head
             self.ctc_output = nn.Sequential(
-                nn.Dropout(p=0.1),
+                nn.Dropout(p=0.1),  # TODO: test removing this
                 nn.Linear(encoder_dim, vocab_size),
-                nn.LogSoftmax(dim=-1),
             )
 
         self.use_attention_decoder = use_attention_decoder
@@ -158,28 +159,82 @@ class AsrModel(nn.Module):
         encoder_out_lens: torch.Tensor,
         targets: torch.Tensor,
         target_lengths: torch.Tensor,
-    ) -> torch.Tensor:
-        """Compute CTC loss.
+        use_consistency_reg: bool = False,
+        use_smooth_reg: bool = False,
+        smooth_kernel: List[float] = [0.25, 0.5, 0.25],
+        eps: float = 1e-6,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Compute CTC loss, with optional consistency and smooth regularization.
         Args:
           encoder_out:
-            Encoder output, of shape (N, T, C).
+            Encoder output, of shape (N or 2 * N, T, C).
          encoder_out_lens:
-            Encoder output lengths, of shape (N,).
+            Encoder output lengths, of shape (N or 2 * N,).
          targets:
-            Target Tensor of shape (sum(target_lengths)). The targets are assumed
+            Target Tensor of shape (2 * sum(target_lengths)). The targets are assumed
             to be un-padded and concatenated within 1 dimension.
+          use_consistency_reg:
+            Whether to use consistency regularization.
+          use_smooth_reg:
+            Whether to use smooth regularization.
""" - # Compute CTC log-prob - ctc_output = self.ctc_output(encoder_out) # (N, T, C) + ctc_output = self.ctc_output(encoder_out) # (N or 2 * N, T, C) + length_mask = make_pad_mask(encoder_out_lens).unsqueeze(-1) + if not use_smooth_reg: + ctc_log_probs = F.log_softmax(ctc_output, dim=-1) + else: + ctc_probs = ctc_output.softmax(dim=-1) # Used in sr_loss + ctc_log_probs = (ctc_probs + eps).log() + + # Compute CTC loss ctc_loss = torch.nn.functional.ctc_loss( - log_probs=ctc_output.permute(1, 0, 2), # (T, N, C) + log_probs=ctc_log_probs.permute(1, 0, 2), # (T, N or 2 * N, C) targets=targets.cpu(), input_lengths=encoder_out_lens.cpu(), target_lengths=target_lengths.cpu(), reduction="sum", ) - return ctc_loss + + if use_consistency_reg: + assert ctc_log_probs.shape[0] % 2 == 0 + # Compute cr_loss + exchanged_targets = ctc_log_probs.detach().chunk(2, dim=0) + exchanged_targets = torch.cat( + [exchanged_targets[1], exchanged_targets[0]], dim=0 + ) # exchange: [x1, x2] -> [x2, x1] + cr_loss = nn.functional.kl_div( + input=ctc_log_probs, + target=exchanged_targets, + reduction="none", + log_target=True, + ) # (2 * N, T, C) + cr_loss = cr_loss.masked_fill(length_mask, 0.0).sum() + else: + cr_loss = torch.empty(0) + + if use_smooth_reg: + # Hard code the kernel here, could try other values + assert len(smooth_kernel) == 3 and sum(smooth_kernel) == 1.0, smooth_kernel + smooth_kernel = torch.tensor(smooth_kernel, dtype=ctc_probs.dtype, + device=ctc_probs.device, requires_grad=False) + smooth_kernel = smooth_kernel.unsqueeze(0).unsqueeze(1).expand(ctc_probs.shape[-1], 1, 3) + # Now kernel: (C, 1, 3) + smoothed_ctc_probs = F.conv1d( + ctc_probs.detach().permute(0, 2, 1), # (N or 2 * N, C, T) + weight=smooth_kernel, stride=1, padding=0, groups=ctc_probs.shape[-1] + ).permute(0, 2, 1) # (N or 2 * N, T - 2, C) + sr_loss = nn.functional.kl_div( + input=ctc_log_probs[:, 1:-1], + target=(smoothed_ctc_probs + eps).log(), + reduction="none", + log_target=True, + ) # (N, T - 1 , C) + sr_loss = sr_loss.masked_fill(length_mask[:, 1:-1], 0.0).sum() + else: + sr_loss = torch.empty(0) + + return ctc_loss, cr_loss, sr_loss def forward_transducer( self, @@ -296,7 +351,13 @@ class AsrModel(nn.Module): prune_range: int = 5, am_scale: float = 0.0, lm_scale: float = 0.0, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + use_cr_ctc: bool = False, + use_sr_ctc: bool = False, + use_spec_aug: bool = False, + spec_augment: Optional[SpecAugment] = None, + supervision_segments: Optional[torch.Tensor] = None, + time_warp_factor: Optional[int] = 80, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Args: x: @@ -316,9 +377,28 @@ class AsrModel(nn.Module): lm_scale: The scale to smooth the loss with lm (output of predictor network) part + use_cr_ctc: + Whether use consistency-regularized CTC. + use_sr_ctc: + Whether use smooth-regularized CTC. + use_spec_aug: + Whether apply spec-augment manually, used only if use_cr_ctc is True. + spec_augment: + The SpecAugment instance that returns time masks, + used only if use_cr_ctc is True. + supervision_segments: + An int tensor of shape ``(S, 3)``. ``S`` is the number of + supervision segments that exist in ``features``. + Used only if use_cr_ctc is True. + time_warp_factor: + Parameter for the time warping; larger values mean more warping. + Set to ``None``, or less than ``1``, to disable. + Used only if use_cr_ctc is True. 
+
         Returns:
-          Return the transducer losses and CTC loss,
-          in form of (simple_loss, pruned_loss, ctc_loss, attention_decoder_loss)
+          Return the transducer losses, CTC loss, AED loss, and the
+          consistency- and smooth-regularization losses, in form of
+          (simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, sr_loss)
 
         Note:
           Regarding am_scale & lm_scale, it will make the loss-function one of
@@ -334,6 +414,24 @@ class AsrModel(nn.Module):
 
         device = x.device
 
+        if use_cr_ctc:
+            assert self.use_ctc
+            if use_spec_aug:
+                assert spec_augment is not None and spec_augment.time_warp_factor < 1
+                # Apply time warping before input duplicating
+                assert supervision_segments is not None
+                x = time_warp(
+                    x,
+                    time_warp_factor=time_warp_factor,
+                    supervision_segments=supervision_segments,
+                )
+                # Independently apply frequency masking and time masking to the two copies
+                x = spec_augment(x.repeat(2, 1, 1))
+            else:
+                x = x.repeat(2, 1, 1)
+            x_lens = x_lens.repeat(2)
+            y = k2.ragged.cat([y, y], axis=0)
+
         # Compute encoder outputs
         encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)
 
@@ -351,6 +449,9 @@ class AsrModel(nn.Module):
                 am_scale=am_scale,
                 lm_scale=lm_scale,
             )
+            if use_cr_ctc:
+                simple_loss = simple_loss * 0.5
+                pruned_loss = pruned_loss * 0.5
         else:
             simple_loss = torch.empty(0)
             pruned_loss = torch.empty(0)
@@ -358,14 +459,23 @@ class AsrModel(nn.Module):
         if self.use_ctc:
             # Compute CTC loss
             targets = y.values
-            ctc_loss = self.forward_ctc(
+            ctc_loss, cr_loss, sr_loss = self.forward_ctc(
                 encoder_out=encoder_out,
                 encoder_out_lens=encoder_out_lens,
                 targets=targets,
                 target_lengths=y_lens,
+                use_consistency_reg=use_cr_ctc,
+                use_smooth_reg=use_sr_ctc,
             )
+            if use_cr_ctc:
+                # We duplicate the batch when use_cr_ctc is True
+                ctc_loss = ctc_loss * 0.5
+                cr_loss = cr_loss * 0.5
+                sr_loss = sr_loss * 0.5
         else:
             ctc_loss = torch.empty(0)
+            cr_loss = torch.empty(0)
+            sr_loss = torch.empty(0)
 
         if self.use_attention_decoder:
             attention_decoder_loss = self.attention_decoder.calc_att_loss(
@@ -374,7 +484,9 @@ class AsrModel(nn.Module):
                 ys=y.to(device),
                 ys_lens=y_lens.to(device),
             )
+            if use_cr_ctc:
+                attention_decoder_loss = attention_decoder_loss * 0.5
         else:
             attention_decoder_loss = torch.empty(0)
 
-        return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss
+        return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, sr_loss
diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py
index 9c1c7f5a7..ba1332300 100755
--- a/egs/librispeech/ASR/zipformer/train.py
+++ b/egs/librispeech/ASR/zipformer/train.py
@@ -72,6 +72,7 @@ from attention_decoder import AttentionDecoderModel
 from decoder import Decoder
 from joiner import Joiner
 from lhotse.cut import Cut
+from lhotse.dataset import SpecAugment
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import AsrModel
@@ -304,6 +305,20 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="If True, use attention-decoder head.",
     )
 
+    parser.add_argument(
+        "--use-cr-ctc",
+        type=str2bool,
+        default=False,
+        help="If True, use consistency-regularized CTC.",
+    )
+
+    parser.add_argument(
+        "--use-sr-ctc",
+        type=str2bool,
+        default=False,
+        help="If True, use smooth-regularized CTC.",
+    )
+
 
 def get_parser():
     parser = argparse.ArgumentParser(
@@ -449,6 +464,27 @@ def get_parser():
         help="Scale for CTC loss.",
     )
 
+    parser.add_argument(
+        "--cr-loss-scale",
+        type=float,
+        default=0.2,
+        help="Scale for consistency-regularization loss.",
+    )
+
+    parser.add_argument(
"--sr-loss-scale", + type=float, + default=0.2, + help="Scale for smooth-regularization loss.", + ) + + parser.add_argument( + "--time-mask-ratio", + type=float, + default=2.5, + help="When using cr-ctc, we increase the amount of time-masking in SpecAugment.", + ) + parser.add_argument( "--attention-decoder-loss-scale", type=float, @@ -717,6 +753,24 @@ def get_model(params: AttributeDict) -> nn.Module: return model +def get_spec_augment(params: AttributeDict) -> SpecAugment: + num_frame_masks = int(10 * params.time_mask_ratio) + max_frames_mask_fraction = 0.15 * params.time_mask_ratio + logging.info( + f"num_frame_masks: {num_frame_masks}, " + f"max_frames_mask_fraction: {max_frames_mask_fraction}" + ) + spec_augment = SpecAugment( + time_warp_factor=0, # Do time warping in model.py + num_frame_masks=num_frame_masks, # default: 10 + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + max_frames_mask_fraction=max_frames_mask_fraction, # default: 0.15 + ) + return spec_augment + + def load_checkpoint_if_available( params: AttributeDict, model: nn.Module, @@ -839,6 +893,7 @@ def compute_loss( sp: spm.SentencePieceProcessor, batch: dict, is_training: bool, + spec_augment: Optional[SpecAugment] = None, ) -> Tuple[Tensor, MetricsTracker]: """ Compute loss given the model and its inputs. @@ -855,8 +910,8 @@ def compute_loss( True for training. False for validation. When it is True, this function enables autograd during computation; when it is False, it disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. + spec_augment: + The SpecAugment instance used only when use_cr_ctc is True. """ device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] @@ -874,14 +929,36 @@ def compute_loss( y = sp.encode(texts, out_type=int) y = k2.RaggedTensor(y) + use_cr_ctc = params.use_cr_ctc + use_sr_ctc = params.use_sr_ctc + use_spec_aug = use_cr_ctc and is_training + if use_spec_aug: + supervision_intervals = batch["supervisions"] + supervision_segments = torch.stack( + [ + supervision_intervals["sequence_idx"], + supervision_intervals["start_frame"], + supervision_intervals["num_frames"], + ], + dim=1, + ) # shape: (S, 3) + else: + supervision_segments = None + with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, ctc_loss, attention_decoder_loss = model( + simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss, sr_loss = model( x=feature, x_lens=feature_lens, y=y, prune_range=params.prune_range, am_scale=params.am_scale, lm_scale=params.lm_scale, + use_cr_ctc=use_cr_ctc, + use_sr_ctc=use_sr_ctc, + use_spec_aug=use_spec_aug, + spec_augment=spec_augment, + supervision_segments=supervision_segments, + time_warp_factor=params.spec_aug_time_warp_factor, ) loss = 0.0 @@ -904,6 +981,10 @@ def compute_loss( if params.use_ctc: loss += params.ctc_loss_scale * ctc_loss + if use_cr_ctc: + loss += params.cr_loss_scale * cr_loss + if use_sr_ctc: + loss += params.sr_loss_scale * sr_loss if params.use_attention_decoder: loss += params.attention_decoder_loss_scale * attention_decoder_loss @@ -922,6 +1003,10 @@ def compute_loss( info["pruned_loss"] = pruned_loss.detach().cpu().item() if params.use_ctc: info["ctc_loss"] = ctc_loss.detach().cpu().item() + if use_cr_ctc: + info["cr_loss"] = cr_loss.detach().cpu().item() + if use_sr_ctc: + info["sr_loss"] = sr_loss.detach().cpu().item() if params.use_attention_decoder: 
info["attn_decoder_loss"] = attention_decoder_loss.detach().cpu().item() @@ -971,6 +1056,7 @@ def train_one_epoch( train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, scaler: GradScaler, + spec_augment: Optional[SpecAugment] = None, model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -997,6 +1083,8 @@ def train_one_epoch( Dataloader for the validation dataset. scaler: The scaler used for mix precision training. + spec_augment: + The SpecAugment instance used only when use_cr_ctc is True. model_avg: The stored model averaged from the start of training. tb_writer: @@ -1043,6 +1131,7 @@ def train_one_epoch( sp=sp, batch=batch, is_training=True, + spec_augment=spec_augment, ) # summary stats tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info @@ -1238,6 +1327,13 @@ def run(rank, world_size, args): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") + if params.use_cr_ctc: + assert params.use_ctc + assert not params.enable_spec_aug # we will do spec_augment in model.py + spec_augment = get_spec_augment(params) + else: + spec_augment = None + assert params.save_every_n >= params.average_period model_avg: Optional[nn.Module] = None if rank == 0: @@ -1360,6 +1456,7 @@ def run(rank, world_size, args): optimizer=optimizer, sp=sp, params=params, + spec_augment=spec_augment, ) scaler = GradScaler(enabled=params.use_autocast, init_scale=1.0) @@ -1387,6 +1484,7 @@ def run(rank, world_size, args): train_dl=train_dl, valid_dl=valid_dl, scaler=scaler, + spec_augment=spec_augment, tb_writer=tb_writer, world_size=world_size, rank=rank, @@ -1452,6 +1550,7 @@ def scan_pessimistic_batches_for_oom( optimizer: torch.optim.Optimizer, sp: spm.SentencePieceProcessor, params: AttributeDict, + spec_augment: Optional[SpecAugment] = None, ): from lhotse.dataset import find_pessimistic_batches @@ -1471,6 +1570,7 @@ def scan_pessimistic_batches_for_oom( sp=sp, batch=batch, is_training=True, + spec_augment=spec_augment, ) loss.backward() optimizer.zero_grad() diff --git a/icefall/utils.py b/icefall/utils.py index 1dbb954de..b0a42cefa 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -21,6 +21,7 @@ import argparse import collections import logging import os +import random import re import subprocess from collections import defaultdict @@ -38,6 +39,7 @@ import sentencepiece as spm import torch import torch.distributed as dist import torch.nn as nn +from lhotse.dataset.signal_transforms import time_warp as time_warp_impl from pypinyin import lazy_pinyin, pinyin from pypinyin.contrib.tone_convert import to_finals, to_finals_tone, to_initials from torch.utils.tensorboard import SummaryWriter @@ -2271,3 +2273,41 @@ def num_tokens( if 0 in ans: num_tokens -= 1 return num_tokens + + +# Based on https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/signal_transforms.py +def time_warp( + features: torch.Tensor, + p: float = 0.9, + time_warp_factor: Optional[int] = 80, + supervision_segments: Optional[torch.Tensor] = None, +): + """Apply time warping on a batch of features + """ + if time_warp_factor is None or time_warp_factor < 1: + return features + assert len(features.shape) == 3, ( + "SpecAugment only supports batches of single-channel feature matrices." + ) + features = features.clone() + if supervision_segments is None: + # No supervisions - apply spec augment to full feature matrices. 
+        for sequence_idx in range(features.size(0)):
+            if random.random() > p:
+                # Randomly choose whether this transform is applied
+                continue
+            features[sequence_idx] = time_warp_impl(
+                features[sequence_idx], factor=time_warp_factor
+            )
+    else:
+        # Supervisions provided - we will apply time warping only on the supervised areas.
+        for sequence_idx, start_frame, num_frames in supervision_segments:
+            if random.random() > p:
+                # Randomly choose whether this transform is applied
+                continue
+            end_frame = start_frame + num_frames
+            features[sequence_idx, start_frame:end_frame] = time_warp_impl(
+                features[sequence_idx, start_frame:end_frame], factor=time_warp_factor
+            )
+
+    return features
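
Note (not part of the patch): the consistency-regularization term added in forward_ctc() above can be illustrated in isolation. The sketch below mirrors the cr_loss computation from the diff on a toy batch: the batch is duplicated, each copy receives different SpecAugment masks, and each copy's CTC posteriors are pulled towards the detached posteriors of the other copy via KL divergence. The tensor sizes are made-up toy values, and the padding mask applied in the real code is omitted for brevity.

import torch
import torch.nn.functional as F

N, T, C = 2, 5, 7  # toy sizes: N utterances (duplicated to 2 * N), T frames, C tokens
ctc_output = torch.randn(2 * N, T, C)  # logits for the two augmented copies of the batch
ctc_log_probs = F.log_softmax(ctc_output, dim=-1)

# Exchange the two halves so that each copy is supervised by the other copy's
# detached distribution: [x1, x2] -> [x2, x1].
first, second = ctc_log_probs.detach().chunk(2, dim=0)
exchanged_targets = torch.cat([second, first], dim=0)

cr_loss = F.kl_div(
    input=ctc_log_probs,
    target=exchanged_targets,
    reduction="none",
    log_target=True,
).sum()  # the real code masks out padding frames before summing
print(cr_loss)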