From 0f88a3a6c3c4051d3f8feb20ec1a207d504e2c53 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 30 May 2025 15:42:31 +0800 Subject: [PATCH] First working example --- .../asr_datamodule_with_parallel_aug.py | 45 ++++++++----- .../ASR/zipformer/train_with_aug.py | 64 +++++++++++++++---- 2 files changed, 83 insertions(+), 26 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/asr_datamodule_with_parallel_aug.py b/egs/librispeech/ASR/zipformer/asr_datamodule_with_parallel_aug.py index 8723b223e..7b76b63c0 100644 --- a/egs/librispeech/ASR/zipformer/asr_datamodule_with_parallel_aug.py +++ b/egs/librispeech/ASR/zipformer/asr_datamodule_with_parallel_aug.py @@ -28,7 +28,6 @@ import torch from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy from lhotse.cut import Cut from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures - CutConcatenate, CutMix, DynamicBucketingSampler, K2SpeechRecognitionDataset, @@ -57,21 +56,33 @@ class _SeedWorkers: fix_random_seed(self.seed + worker_id) +""" +We use c.features = None below to suppress the following warnings + +2025-05-29 16:49:55,253 WARNING [data.py:801] Attempting to perturb speed on a +DataCut that references pre-computed features. The feature manifest will be +detached, as we do not support feature-domain speed perturbation. +""" + + def perturb_speed(c: Cut): - factor = random.uniform(0.9, 1.1) - print("perturb_speed factor", factor) + factor = random.choice([0.9, 1.1]) + c.features = None + return lhotse.MonoCut.perturb_speed(c, factor) def perturb_volume(c: Cut): - factor = random.uniform(0.9, 1.1) - print("perturb_volume factor", factor) + factor = random.choice([0.9, 1.1]) + c.features = None + return lhotse.MonoCut.perturb_volume(c, factor) def perturb_tempo(c: Cut): - factor = random.uniform(0.9, 1.1) - print("perturb_tempo factor", factor) + factor = random.choice([0.9, 1.1]) + + c.features = None return lhotse.MonoCut.perturb_tempo(c, factor) @@ -86,7 +97,6 @@ class LibriSpeechAsrDataModuleWithParallelAug: experiments, e.g.: - dynamic batch size, - bucketing samplers, - - cut concatenation, - augmentation, - on-the-fly feature extraction @@ -112,6 +122,12 @@ class LibriSpeechAsrDataModuleWithParallelAug: help="""Used only when --mini-libri is False.When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.""", ) + group.add_argument( + "--enable-augmentation", + type=str2bool, + default=True, + help="True to enable augmentation for training set", + ) group.add_argument( "--mini-libri", type=str2bool, @@ -204,7 +220,12 @@ class LibriSpeechAsrDataModuleWithParallelAug: sampler_state_dict: The state dict for the training sampler. """ - transforms = [perturb_speed, perturb_volume, perturb_tempo] + if self.args.enable_augmentation: + logging.info("Augmentation is enabled") + transforms = [perturb_speed, perturb_volume, perturb_tempo] + else: + logging.info("Augmentation is disabled") + transforms = [] logging.info("About to create train dataset") train = ConsistencyRegularizationSpeechRecognitionDataset( @@ -254,12 +275,6 @@ class LibriSpeechAsrDataModuleWithParallelAug: def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: transforms = [] - if self.args.concatenate_cuts: - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms logging.info("About to create dev dataset") if self.args.on_the_fly_feats: diff --git a/egs/librispeech/ASR/zipformer/train_with_aug.py b/egs/librispeech/ASR/zipformer/train_with_aug.py index bec9f96c6..71961a6dc 100755 --- a/egs/librispeech/ASR/zipformer/train_with_aug.py +++ b/egs/librispeech/ASR/zipformer/train_with_aug.py @@ -855,19 +855,38 @@ def compute_loss( disables autograd. """ device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] + + feature_len_seq = [batch["supervisions"]["num_frames"]] + text_seq = list(batch["supervisions"]["text"]) + feature_seq = torch.nn.utils.rnn.unpad_sequence( + batch["inputs"], + batch["supervisions"]["num_frames"], + batch_first=True, + ) + + if "aug" in batch: + for aug in batch["aug"]: + feature_len_seq.append(aug["supervisions"]["num_frames"]) + text_seq.extend(aug["supervisions"]["text"]) + + feature_seq.extend( + torch.nn.utils.rnn.unpad_sequence( + aug["inputs"], + aug["supervisions"]["num_frames"], + batch_first=True, + ) + ) + + feature_lens = torch.cat(feature_len_seq).to(device) + feature = torch.nn.utils.rnn.pad_sequence(feature_seq, batch_first=True).to(device) + # at entry, feature is (N, T, C) assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) batch_idx_train = params.batch_idx_train warm_step = params.warm_step - texts = batch["supervisions"]["text"] - y = sp.encode(texts, out_type=int) + y = sp.encode(text_seq, out_type=int) y = k2.RaggedTensor(y) with torch.set_grad_enabled(is_training): @@ -1029,6 +1048,9 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) + if "aug" in batch: + batch_size *= len(batch["aug"]) + 1 + try: with torch.cuda.amp.autocast( enabled=params.use_autocast, dtype=params.dtype @@ -1360,7 +1382,7 @@ def run(rank, world_size, args): valid_cuts += librispeech.dev_other_cuts() valid_dl = librispeech.valid_dataloaders(valid_cuts) - if not params.print_diagnostics: + if False and not params.print_diagnostics: scan_pessimistic_batches_for_oom( model=model, train_dl=train_dl, @@ -1443,12 +1465,32 @@ def display_and_save_batch( logging.info(f"Saving batch to {filename}") torch.save(batch, filename) - supervisions = batch["supervisions"] - features = batch["inputs"] + feature_len_seq = [batch["supervisions"]["num_frames"]] + text_seq = list(batch["supervisions"]["text"]) + feature_seq = torch.nn.utils.rnn.unpad_sequence( + batch["inputs"], + batch["supervisions"]["num_frames"], + batch_first=True, + ) + + if "aug" in batch: + for aug in batch["aug"]: + feature_len_seq.append(aug["supervisions"]["num_frames"]) + text_seq.extend(aug["supervisions"]["text"]) + + feature_seq.extend( + torch.nn.utils.rnn.unpad_sequence( + aug["inputs"], + aug["supervisions"]["num_frames"], + batch_first=True, + ) + ) + + features = torch.nn.utils.rnn.pad_sequence(feature_seq, batch_first=True) logging.info(f"features shape: {features.shape}") - y = sp.encode(supervisions["text"], out_type=int) + y = sp.encode(text_seq, out_type=int) num_tokens = sum(len(i) for i in y) logging.info(f"num tokens: {num_tokens}")