From dc74705d20f394918dda51486c1d8dfb0abae9bd Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Thu, 29 May 2025 11:49:30 +0800
Subject: [PATCH] remove cr-loss

---
 .../ASR/zipformer/train_with_aug.py | 88 +------------------
 1 file changed, 3 insertions(+), 85 deletions(-)

diff --git a/egs/librispeech/ASR/zipformer/train_with_aug.py b/egs/librispeech/ASR/zipformer/train_with_aug.py
index f8864d58b..eb234bdcd 100755
--- a/egs/librispeech/ASR/zipformer/train_with_aug.py
+++ b/egs/librispeech/ASR/zipformer/train_with_aug.py
@@ -48,7 +48,6 @@ It supports training with:
   - transducer loss (default)
   - ctc loss
   - attention decoder loss
-  - cr-ctc loss (should use half the max-duration compared to regular ctc)
 
 """
 
@@ -66,7 +65,7 @@ import sentencepiece as spm
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
-from asr_datamodule import LibriSpeechAsrDataModule
+from asr_datamodule_with_parallel_aug import LibriSpeechAsrDataModuleWithParallelAug
 from attention_decoder import AttentionDecoderModel
 from decoder import Decoder
 from joiner import Joiner
@@ -304,13 +303,6 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="If True, use attention-decoder head.",
     )
 
-    parser.add_argument(
-        "--use-cr-ctc",
-        type=str2bool,
-        default=False,
-        help="If True, use consistency-regularized CTC.",
-    )
-
 
 def get_parser():
     parser = argparse.ArgumentParser(
@@ -456,20 +448,6 @@ def get_parser():
         help="Scale for CTC loss.",
     )
 
-    parser.add_argument(
-        "--cr-loss-scale",
-        type=float,
-        default=0.2,
-        help="Scale for consistency-regularization loss.",
-    )
-
-    parser.add_argument(
-        "--time-mask-ratio",
-        type=float,
-        default=2.5,
-        help="When using cr-ctc, we increase the amount of time-masking in SpecAugment.",
-    )
-
     parser.add_argument(
         "--attention-decoder-loss-scale",
         type=float,
@@ -738,24 +716,6 @@ def get_model(params: AttributeDict) -> nn.Module:
     return model
 
 
-def get_spec_augment(params: AttributeDict) -> SpecAugment:
-    num_frame_masks = int(10 * params.time_mask_ratio)
-    max_frames_mask_fraction = 0.15 * params.time_mask_ratio
-    logging.info(
-        f"num_frame_masks: {num_frame_masks}, "
-        f"max_frames_mask_fraction: {max_frames_mask_fraction}"
-    )
-    spec_augment = SpecAugment(
-        time_warp_factor=0,  # Do time warping in model.py
-        num_frame_masks=num_frame_masks,  # default: 10
-        features_mask_size=27,
-        num_feature_masks=2,
-        frames_mask_size=100,
-        max_frames_mask_fraction=max_frames_mask_fraction,  # default: 0.15
-    )
-    return spec_augment
-
-
 def load_checkpoint_if_available(
     params: AttributeDict,
     model: nn.Module,
@@ -878,7 +838,6 @@ def compute_loss(
     sp: spm.SentencePieceProcessor,
     batch: dict,
     is_training: bool,
-    spec_augment: Optional[SpecAugment] = None,
 ) -> Tuple[Tensor, MetricsTracker]:
     """
     Compute loss given the model and its inputs.
@@ -895,8 +854,6 @@ def compute_loss(
       True for training. False for validation. When it is True, this
       function enables autograd during computation; when it is False, it
       disables autograd.
-    spec_augment:
-      The SpecAugment instance used only when use_cr_ctc is True.
     """
     device = model.device if isinstance(model, DDP) else next(model.parameters()).device
     feature = batch["inputs"]
@@ -914,21 +871,6 @@ def compute_loss(
     y = sp.encode(texts, out_type=int)
     y = k2.RaggedTensor(y)
 
-    use_cr_ctc = params.use_cr_ctc
-    use_spec_aug = use_cr_ctc and is_training
-    if use_spec_aug:
-        supervision_intervals = batch["supervisions"]
-        supervision_segments = torch.stack(
-            [
-                supervision_intervals["sequence_idx"],
-                supervision_intervals["start_frame"],
-                supervision_intervals["num_frames"],
-            ],
-            dim=1,
-        )  # shape: (S, 3)
-    else:
-        supervision_segments = None
-
     with torch.set_grad_enabled(is_training):
         simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss = model(
             x=feature,
@@ -937,11 +879,6 @@ def compute_loss(
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
-            use_cr_ctc=use_cr_ctc,
-            use_spec_aug=use_spec_aug,
-            spec_augment=spec_augment,
-            supervision_segments=supervision_segments,
-            time_warp_factor=params.spec_aug_time_warp_factor,
         )
 
         loss = 0.0
@@ -964,8 +901,6 @@ def compute_loss(
 
         if params.use_ctc:
             loss += params.ctc_loss_scale * ctc_loss
-            if use_cr_ctc:
-                loss += params.cr_loss_scale * cr_loss
 
         if params.use_attention_decoder:
             loss += params.attention_decoder_loss_scale * attention_decoder_loss
@@ -984,8 +919,6 @@ def compute_loss(
         info["pruned_loss"] = pruned_loss.detach().cpu().item()
     if params.use_ctc:
         info["ctc_loss"] = ctc_loss.detach().cpu().item()
-    if params.use_cr_ctc:
-        info["cr_loss"] = cr_loss.detach().cpu().item()
     if params.use_attention_decoder:
         info["attn_decoder_loss"] = attention_decoder_loss.detach().cpu().item()
 
@@ -1035,7 +968,6 @@ def train_one_epoch(
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
     scaler: GradScaler,
-    spec_augment: Optional[SpecAugment] = None,
    model_avg: Optional[nn.Module] = None,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
@@ -1062,8 +994,6 @@ def train_one_epoch(
       Dataloader for the validation dataset.
     scaler:
       The scaler used for mix precision training.
-    spec_augment:
-      The SpecAugment instance used only when use_cr_ctc is True.
     model_avg:
       The stored model averaged from the start of training.
     tb_writer:
@@ -1110,7 +1040,6 @@ def train_one_epoch(
                     sp=sp,
                     batch=batch,
                     is_training=True,
-                    spec_augment=spec_augment,
                 )
             # summary stats
             tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
@@ -1317,13 +1246,6 @@ def run(rank, world_size, args):
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
-    if params.use_cr_ctc:
-        assert params.use_ctc
-        assert not params.enable_spec_aug  # we will do spec_augment in model.py
-        spec_augment = get_spec_augment(params)
-    else:
-        spec_augment = None
-
     assert params.save_every_n >= params.average_period
     model_avg: Optional[nn.Module] = None
     if rank == 0:
@@ -1369,7 +1291,7 @@ def run(rank, world_size, args):
     if params.inf_check:
         register_inf_check_hooks(model)
 
-    librispeech = LibriSpeechAsrDataModule(args)
+    librispeech = LibriSpeechAsrDataModuleWithParallelAug(args)
 
     if params.full_libri:
         train_cuts = librispeech.train_all_shuf_cuts()
@@ -1446,7 +1368,6 @@ def run(rank, world_size, args):
            optimizer=optimizer,
            sp=sp,
            params=params,
-           spec_augment=spec_augment,
        )
 
    scaler = GradScaler(enabled=params.use_autocast, init_scale=1.0)
@@ -1474,7 +1395,6 @@ def run(rank, world_size, args):
            train_dl=train_dl,
            valid_dl=valid_dl,
            scaler=scaler,
-           spec_augment=spec_augment,
            tb_writer=tb_writer,
            world_size=world_size,
            rank=rank,
@@ -1540,7 +1460,6 @@ def scan_pessimistic_batches_for_oom(
    optimizer: torch.optim.Optimizer,
    sp: spm.SentencePieceProcessor,
    params: AttributeDict,
-   spec_augment: Optional[SpecAugment] = None,
 ):
    from lhotse.dataset import find_pessimistic_batches
 
@@ -1560,7 +1479,6 @@ def scan_pessimistic_batches_for_oom(
                    sp=sp,
                    batch=batch,
                    is_training=True,
-                   spec_augment=spec_augment,
                )
            loss.backward()
            optimizer.zero_grad()
@@ -1582,7 +1500,7 @@ def scan_pessimistic_batches_for_oom(
 
 def main():
    parser = get_parser()
-   LibriSpeechAsrDataModule.add_arguments(parser)
+   LibriSpeechAsrDataModuleWithParallelAug.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)
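Note: after this patch, train_with_aug.py no longer accepts --use-cr-ctc, --cr-loss-scale, or --time-mask-ratio, and its dataset arguments are registered by LibriSpeechAsrDataModuleWithParallelAug.add_arguments() instead of LibriSpeechAsrDataModule.add_arguments(). As a rough sanity check, a minimal invocation might look like the sketch below; apart from the removed flags, the option names are assumptions based on the regular icefall zipformer recipe rather than taken from this patch, so adjust them to your setup.

  # Hypothetical invocation sketch; flag names other than the removed ones
  # are assumed to match the regular zipformer recipe and the new data module.
  ./zipformer/train_with_aug.py \
    --world-size 4 \
    --num-epochs 30 \
    --exp-dir zipformer/exp-with-aug \
    --full-libri 1 \
    --use-ctc 1 \
    --use-attention-decoder 1 \
    --max-duration 600
  # --use-cr-ctc, --cr-loss-scale, and --time-mask-ratio are rejected after this change.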