From 5ec7297f3251ee6f267b212e8095a57d215a18aa Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Thu, 29 May 2025 11:44:40 +0800
Subject: [PATCH] Add a dataset example for LibriSpeech

---
 .../asr_datamodule_with_parallel_aug.py | 160 ++++--------
 1 file changed, 34 insertions(+), 126 deletions(-)

diff --git a/egs/librispeech/ASR/zipformer/asr_datamodule_with_parallel_aug.py b/egs/librispeech/ASR/zipformer/asr_datamodule_with_parallel_aug.py
index 1b52aa8b5..8723b223e 100644
--- a/egs/librispeech/ASR/zipformer/asr_datamodule_with_parallel_aug.py
+++ b/egs/librispeech/ASR/zipformer/asr_datamodule_with_parallel_aug.py
@@ -17,14 +17,16 @@

 import argparse
-import inspect
 import logging
+import random
 from functools import lru_cache
 from pathlib import Path
 from typing import Any, Dict, Optional

+import lhotse
 import torch
-from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
+from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy
+from lhotse.cut import Cut
 from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
     CutConcatenate,
     CutMix,
@@ -41,6 +43,9 @@ from lhotse.dataset.input_strategies import (  # noqa F401 For AudioSamples
 from lhotse.utils import fix_random_seed
 from torch.utils.data import DataLoader

+from icefall.speech_recognition_dataset import (
+    ConsistencyRegularizationSpeechRecognitionDataset,
+)
 from icefall.utils import str2bool


@@ -52,7 +57,25 @@ class _SeedWorkers:
         fix_random_seed(self.seed + worker_id)


-class LibriSpeechAsrDataModule:
+def perturb_speed(c: Cut) -> Cut:
+    factor = random.uniform(0.9, 1.1)
+    print("perturb_speed factor", factor)
+    return lhotse.MonoCut.perturb_speed(c, factor)
+
+
+def perturb_volume(c: Cut) -> Cut:
+    factor = random.uniform(0.9, 1.1)
+    print("perturb_volume factor", factor)
+    return lhotse.MonoCut.perturb_volume(c, factor)
+
+
+def perturb_tempo(c: Cut) -> Cut:
+    factor = random.uniform(0.9, 1.1)
+    print("perturb_tempo factor", factor)
+    return lhotse.MonoCut.perturb_tempo(c, factor)
+
+
+class LibriSpeechAsrDataModuleWithParallelAug:
     """
     DataModule for k2 ASR experiments.
     It assumes there is always one train and valid dataloader,
@@ -123,36 +146,6 @@ class LibriSpeechAsrDataModule:
             help="The number of buckets for the DynamicBucketingSampler"
             "(you might want to increase it for larger datasets).",
         )
-        group.add_argument(
-            "--concatenate-cuts",
-            type=str2bool,
-            default=False,
-            help="When enabled, utterances (cuts) will be concatenated "
-            "to minimize the amount of padding.",
-        )
-        group.add_argument(
-            "--duration-factor",
-            type=float,
-            default=1.0,
-            help="Determines the maximum duration of a concatenated cut "
-            "relative to the duration of the longest cut in a batch.",
-        )
-        group.add_argument(
-            "--gap",
-            type=float,
-            default=1.0,
-            help="The amount of padding (in seconds) inserted between "
-            "concatenated cuts. This padding is filled with noise when "
-            "noise augmentation is used.",
-        )
-        group.add_argument(
-            "--on-the-fly-feats",
-            type=str2bool,
-            default=False,
-            help="When enabled, use on-the-fly cut mixing and feature "
-            "extraction. Will drop existing precomputed feature manifests "
-            "if available.",
-        )
         group.add_argument(
             "--shuffle",
             type=str2bool,
             default=True,
             help="When enabled (=default), the examples will be "
             "shuffled for each epoch.",
         )
@@ -184,28 +177,12 @@ class LibriSpeechAsrDataModule:
         )

         group.add_argument(
-            "--enable-spec-aug",
+            "--on-the-fly-feats",
             type=str2bool,
-            default=True,
-            help="When enabled, use SpecAugment for training dataset.",
-        )
-
-        group.add_argument(
-            "--spec-aug-time-warp-factor",
-            type=int,
-            default=80,
-            help="Used only when --enable-spec-aug is True. "
-            "It specifies the factor for time warping in SpecAugment. "
-            "Larger values mean more warping. "
-            "A value less than 1 means to disable time warp.",
-        )
-
-        group.add_argument(
-            "--enable-musan",
-            type=str2bool,
-            default=True,
-            help="When enabled, select noise from MUSAN and mix it"
-            "with training dataset. ",
+            default=False,
+            help="When enabled, use on-the-fly cut mixing and feature "
+            "extraction. Will drop existing precomputed feature manifests "
+            "if available. The training dataloader always uses on-the-fly features.",
         )

         group.add_argument(
@@ -227,83 +204,14 @@ class LibriSpeechAsrDataModule:
          sampler_state_dict:
            The state dict for the training sampler.
        """
-        transforms = []
-        if self.args.enable_musan:
-            logging.info("Enable MUSAN")
-            logging.info("About to get Musan cuts")
-            cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
-            transforms.append(
-                CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
-            )
-        else:
-            logging.info("Disable MUSAN")
-
-        if self.args.concatenate_cuts:
-            logging.info(
-                f"Using cut concatenation with duration factor "
-                f"{self.args.duration_factor} and gap {self.args.gap}."
-            )
-            # Cut concatenation should be the first transform in the list,
-            # so that if we e.g. mix noise in, it will fill the gaps between
-            # different utterances.
-            transforms = [
-                CutConcatenate(
-                    duration_factor=self.args.duration_factor, gap=self.args.gap
-                )
-            ] + transforms
-
-        input_transforms = []
-        if self.args.enable_spec_aug:
-            logging.info("Enable SpecAugment")
-            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
-            # Set the value of num_frame_masks according to Lhotse's version.
-            # In different Lhotse's versions, the default of num_frame_masks is
-            # different.
-            num_frame_masks = 10
-            num_frame_masks_parameter = inspect.signature(
-                SpecAugment.__init__
-            ).parameters["num_frame_masks"]
-            if num_frame_masks_parameter.default == 1:
-                num_frame_masks = 2
-            logging.info(f"Num frame mask: {num_frame_masks}")
-            input_transforms.append(
-                SpecAugment(
-                    time_warp_factor=self.args.spec_aug_time_warp_factor,
-                    num_frame_masks=num_frame_masks,
-                    features_mask_size=27,
-                    num_feature_masks=2,
-                    frames_mask_size=100,
-                )
-            )
-        else:
-            logging.info("Disable SpecAugment")
+        transforms = [perturb_speed, perturb_volume, perturb_tempo]

         logging.info("About to create train dataset")
-        train = K2SpeechRecognitionDataset(
-            input_strategy=eval(self.args.input_strategy)(),
+        train = ConsistencyRegularizationSpeechRecognitionDataset(
+            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
             cut_transforms=transforms,
-            input_transforms=input_transforms,
             return_cuts=self.args.return_cuts,
         )
-
-        if self.args.on_the_fly_feats:
-            # NOTE: the PerturbSpeed transform should be added only if we
-            # remove it from data prep stage.
-            # Add on-the-fly speed perturbation; since originally it would
-            # have increased epoch size by 3, we will apply prob 2/3 and use
-            # 3x more epochs.
-            # Speed perturbation probably should come first before
-            # concatenation, but in principle the transforms order doesn't have
-            # to be strict (e.g. could be randomized)
-            # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms  # noqa
-            # Drop feats to be on the safe side.
-            train = K2SpeechRecognitionDataset(
-                cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
-                input_transforms=input_transforms,
-                return_cuts=self.args.return_cuts,
-            )
-
         if self.args.bucketing_sampler:
             logging.info("Using DynamicBucketingSampler.")
             train_sampler = DynamicBucketingSampler(
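
Usage note (not part of the patch): each of the three cut transforms above draws a fresh random factor per call, so applying them twice to the same cut yields two differently augmented views; presumably that is what ConsistencyRegularizationSpeechRecognitionDataset relies on for its parallel augmentation. Below is a minimal smoke-test sketch, assuming the usual icefall datamodule entry points (add_arguments, train_dataloaders, train_clean_100_cuts) are unchanged by this patch and that the fbank manifests already exist under --manifest-dir:

import argparse

from asr_datamodule_with_parallel_aug import LibriSpeechAsrDataModuleWithParallelAug

parser = argparse.ArgumentParser()
LibriSpeechAsrDataModuleWithParallelAug.add_arguments(parser)
# Use the library defaults; pass e.g. ["--manifest-dir", "data/fbank"]
# (hypothetical path) to override.
args = parser.parse_args([])

librispeech = LibriSpeechAsrDataModuleWithParallelAug(args)
train_dl = librispeech.train_dataloaders(librispeech.train_clean_100_cuts())

# The exact batch layout is defined by the new dataset class in icefall;
# printing the keys is enough to confirm the pipeline runs end to end.
batch = next(iter(train_dl))
print(batch.keys())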