diff --git a/egs/mls_english/ASR/local/compute_fbank_musan.py b/egs/mls_english/ASR/local/compute_fbank_musan.py new file mode 120000 index 000000000..5833f2484 --- /dev/null +++ b/egs/mls_english/ASR/local/compute_fbank_musan.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compute_fbank_musan.py \ No newline at end of file diff --git a/egs/mls_english/ASR/local/utils/asr_datamodule.py b/egs/mls_english/ASR/local/utils/asr_datamodule.py index 250b40a63..6c6a1dd03 100644 --- a/egs/mls_english/ASR/local/utils/asr_datamodule.py +++ b/egs/mls_english/ASR/local/utils/asr_datamodule.py @@ -180,7 +180,10 @@ class MLSEnglishHFAsrDataModule: ) def train_dataloaders( - self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + cuts_musan: Optional[CutSet] = None, ) -> DataLoader: """ Args: @@ -191,6 +194,13 @@ class MLSEnglishHFAsrDataModule: """ transforms = [] + if cuts_musan is not None: + logging.info("Enable MUSAN") + transforms.append( + CutMix(cuts=cuts_musan, p=0.5, snr=(10,20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") input_transforms = [] if self.args.enable_spec_aug: diff --git a/egs/mls_english/ASR/prepare.sh b/egs/mls_english/ASR/prepare.sh index c6582f679..78f169bd1 100755 --- a/egs/mls_english/ASR/prepare.sh +++ b/egs/mls_english/ASR/prepare.sh @@ -16,6 +16,14 @@ vocab_sizes=(2000) # You can add more sizes like (500 1000 2000) for comparison # Directory where dataset will be downloaded dl_dir=$PWD/download +# - $dl_dir/musan +# This directory contains the following directories downloaded from +# http://www.openslr.org/17/ +# +# - music +# - noise +# - speech + . shared/parse_options.sh || exit 1 # All files generated by this script are saved in "data". @@ -32,7 +40,7 @@ log() { log "Starting MLS English data preparation" if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then - log "Stage 0: Download MLS English dataset" + log "Stage 0: Download data" # Check if huggingface_hub is installed if ! python -c "import huggingface_hub" &> /dev/null; then log "huggingface_hub Python library not found. Installing it now..." @@ -55,6 +63,15 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then else log "Dataset already exists at $dl_dir/mls_english. Skipping download." fi + # If you ha`ve predownloaded it to /path/to/musan, + # you can create a symlink + # + # ln -sfv /path/to/musan $dl_dir/ + # + if [ ! -d $dl_dir/musan ] ; then + log "Downloading musan." + lhotse download musan $dl_dir + fi fi if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then @@ -73,7 +90,25 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then fi if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - log "Stage 2: Prepare transcript for BPE training" + log "Stage 2: Prepare musan manifest" + # We assume that you have downloaded the musan corpus + # to $dl_dir/musan + if [ ! -e data/manifests/.musan_prep.done ]; then + lhotse prepare musan $dl_dir/musan data/manifests + touch data/manifests/.musan_prep.done + fi +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Compute fbank for musan" + if [ ! -e data/manifests/.musan_fbank.done ]; then + ./local/compute_fbank_musan.py + touch data/manifests/.musan_fbank.done + fi +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Prepare transcript for BPE training" if [ ! -f data/lang/transcript.txt ]; then log "Generating transcripts for BPE training" python local/utils/generate_transcript.py \ @@ -83,8 +118,8 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then fi fi -if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 3: Prepare BPE tokenizer" +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Prepare BPE tokenizer" for vocab_size in ${vocab_sizes[@]}; do log "Training BPE model with vocab_size=${vocab_size}" bpe_dir=data/lang/bpe_${vocab_size} @@ -99,8 +134,8 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then done fi -if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Show manifest statistics" +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + log "Stage 6: Show manifest statistics" python local/display_manifest_statistics.py --manifest-dir data/manifests > data/manifests/manifest_statistics.txt cat data/manifests/manifest_statistics.txt fi diff --git a/egs/mls_english/ASR/zipformer/train.py b/egs/mls_english/ASR/zipformer/train.py index 7c6997656..cdc4bdad3 100755 --- a/egs/mls_english/ASR/zipformer/train.py +++ b/egs/mls_english/ASR/zipformer/train.py @@ -68,6 +68,7 @@ from joiner import Joiner from lhotse.cut import Cut from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed +from lhotse import load_manifest from model import AsrModel from optim import Eden, ScaledAdam from scaling import ScheduledFloat @@ -1217,9 +1218,6 @@ def run(rank, world_size, args): mls_english_corpus = MLSEnglishHFAsrDataModule(args) mls_english_corpus.load_dataset(args.dataset_path) - # train_cuts = mls_english_corpus.train_cuts() - - # train_cuts = train_cuts.filter(remove_short_and_long_utt) if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: # We only load the sampler's state dict when it loads a checkpoint @@ -1227,16 +1225,24 @@ def run(rank, world_size, args): sampler_state_dict = checkpoints["sampler"] else: sampler_state_dict = None + + if args.enable_musan: + musan_path = Path(args.manifest_dir) / "musan_cuts.jsonl.gz" + if musan_path.exists(): + cuts_musan = load_manifest(musan_path) + logging.info(f"Loaded MUSAN manifest from {musan_path}") + else: + logging.warning(f"MUSAN manifest not found at {musan_path}, disabling MUSAN augmentation") + cuts_musan = None + else: + cuts_musan = None - # train_dl = mls_english_corpus.train_dataloaders( - # train_cuts, sampler_state_dict=sampler_state_dict - # ) + train_dl = mls_english_corpus.train_dataloader( - sampler_state_dict=sampler_state_dict + sampler_state_dict=sampler_state_dict, + cuts_musan=cuts_musan, ) - # valid_cuts = mls_english_corpus.valid_cuts() - # valid_dl = mls_english_corpus.valid_dataloader(valid_cuts) valid_dl = mls_english_corpus.valid_dataloader() if not params.print_diagnostics: