diff --git a/egs/fisher_swbd/ASR/conformer_ctc/train.py b/egs/fisher_swbd/ASR/conformer_ctc/train.py
index c1fa814c0..29f9f6cb6 100755
--- a/egs/fisher_swbd/ASR/conformer_ctc/train.py
+++ b/egs/fisher_swbd/ASR/conformer_ctc/train.py
@@ -27,7 +27,7 @@ import k2
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
-from asr_datamodule import LibriSpeechAsrDataModule
+from asr_datamodule import AsrDataModule
 from conformer import Conformer
 from lhotse.utils import fix_random_seed
 from torch import Tensor
@@ -620,17 +620,13 @@ def run(rank, world_size, args):
     if checkpoints:
         optimizer.load_state_dict(checkpoints["optimizer"])
 
-    librispeech = LibriSpeechAsrDataModule(args)
+    datamodule = AsrDataModule(args)
 
-    train_cuts = librispeech.train_clean_100_cuts()
-    if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
-    train_dl = librispeech.train_dataloaders(train_cuts)
+    train_cuts = datamodule.train_cuts()
+    train_dl = datamodule.train_dataloaders(train_cuts)
 
-    valid_cuts = librispeech.dev_clean_cuts()
-    valid_cuts += librispeech.dev_other_cuts()
-    valid_dl = librispeech.valid_dataloaders(valid_cuts)
+    valid_cuts = datamodule.dev_cuts()
+    valid_dl = datamodule.valid_dataloaders(valid_cuts)
 
     scan_pessimistic_batches_for_oom(
         model=model,
diff --git a/egs/fisher_swbd/ASR/prepare.sh b/egs/fisher_swbd/ASR/prepare.sh
index bf61b1be9..dc89dfdac 100755
--- a/egs/fisher_swbd/ASR/prepare.sh
+++ b/egs/fisher_swbd/ASR/prepare.sh
@@ -92,7 +92,7 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
   # We assume that you have downloaded the LibriSpeech corpus
   # to $dl_dir/LibriSpeech
   mkdir -p data/manifests/fisher
-  lhotse prepare fisher-english $dl_dir data/manifests/fisher
+  lhotse prepare fisher-english --absolute-paths 1 $dl_dir data/manifests/fisher
 fi
 
 if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
@@ -100,7 +100,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
   # We assume that you have downloaded the LibriSpeech corpus
   # to $dl_dir/LibriSpeech
   mkdir -p data/manifests/swbd
-  lhotse prepare switchboard --omit-silence $dl_dir/LDC97S62 data/manifests/swbd
+  lhotse prepare switchboard --absolute-paths 1 --omit-silence $dl_dir/LDC97S62 data/manifests/swbd
 fi
 
 if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
diff --git a/egs/fisher_swbd/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/fisher_swbd/ASR/tdnn_lstm_ctc/asr_datamodule.py
index e075a2d03..40c8468f6 100644
--- a/egs/fisher_swbd/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/fisher_swbd/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -20,14 +20,16 @@ import logging
 from functools import lru_cache
 from pathlib import Path
 
-from lhotse import CutSet, Fbank, FbankConfig, load_manifest
+from tqdm import tqdm
+
+from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
 from lhotse.dataset import (
     BucketingSampler,
-    CutConcatenate,
     CutMix,
+    DynamicBucketingSampler,
     K2SpeechRecognitionDataset,
+    PerturbSpeed,
     PrecomputedFeatures,
-    SingleCutSampler,
     SpecAugment,
 )
 from lhotse.dataset.input_strategies import OnTheFlyFeatures
@@ -36,7 +38,12 @@ from torch.utils.data import DataLoader
 from icefall.utils import str2bool
 
 
-class LibriSpeechAsrDataModule:
+class Resample16kHz:
+    def __call__(self, cuts: CutSet) -> CutSet:
+        return cuts.resample(16000).with_recording_path_prefix('download')
+
+
+class AsrDataModule:
     """
     DataModule for k2 ASR experiments.
     It assumes there is always one train and valid dataloader,
@@ -66,17 +73,10 @@ class LibriSpeechAsrDataModule:
             "effective batch sizes, sampling strategies, applied data "
             "augmentations, etc.",
         )
-        group.add_argument(
-            "--full-libri",
-            type=str2bool,
-            default=True,
-            help="When enabled, use 960h LibriSpeech. "
-            "Otherwise, use 100h subset.",
-        )
         group.add_argument(
             "--manifest-dir",
             type=Path,
-            default=Path("data/fbank"),
+            default=Path("data/manifests"),
             help="Path to directory with train/valid/test cuts.",
         )
         group.add_argument(
@@ -86,13 +86,6 @@ class LibriSpeechAsrDataModule:
             help="Maximum pooled recordings duration (seconds) in a "
             "single batch. You can reduce it if it causes CUDA OOM.",
         )
-        group.add_argument(
-            "--bucketing-sampler",
-            type=str2bool,
-            default=True,
-            help="When enabled, the batches will come from buckets of "
-            "similar duration (saves padding frames).",
-        )
         group.add_argument(
             "--num-buckets",
             type=int,
@@ -100,32 +93,10 @@ class LibriSpeechAsrDataModule:
             help="The number of buckets for the BucketingSampler"
             "(you might want to increase it for larger datasets).",
         )
-        group.add_argument(
-            "--concatenate-cuts",
-            type=str2bool,
-            default=False,
-            help="When enabled, utterances (cuts) will be concatenated "
-            "to minimize the amount of padding.",
-        )
-        group.add_argument(
-            "--duration-factor",
-            type=float,
-            default=1.0,
-            help="Determines the maximum duration of a concatenated cut "
-            "relative to the duration of the longest cut in a batch.",
-        )
-        group.add_argument(
-            "--gap",
-            type=float,
-            default=1.0,
-            help="The amount of padding (in seconds) inserted between "
-            "concatenated cuts. This padding is filled with noise when "
-            "noise augmentation is used.",
-        )
         group.add_argument(
             "--on-the-fly-feats",
             type=str2bool,
-            default=False,
+            default=True,
             help="When enabled, use on-the-fly cut mixing and feature "
             "extraction. Will drop existing precomputed feature manifests "
             "if available.",
@@ -137,30 +108,15 @@ class LibriSpeechAsrDataModule:
             help="When enabled (=default), the examples will be "
             "shuffled for each epoch.",
         )
-        group.add_argument(
-            "--return-cuts",
-            type=str2bool,
-            default=True,
-            help="When enabled, each batch will have the "
-            "field: batch['supervisions']['cut'] with the cuts that "
-            "were used to construct it.",
-        )
         group.add_argument(
             "--num-workers",
             type=int,
-            default=2,
+            default=8,
             help="The number of training dataloader workers that "
            "collect the batches.",
         )
-        group.add_argument(
-            "--enable-spec-aug",
-            type=str2bool,
-            default=True,
-            help="When enabled, use SpecAugment for training dataset.",
-        )
-
         group.add_argument(
             "--spec-aug-time-warp-factor",
             type=int,
@@ -171,52 +127,28 @@ class LibriSpeechAsrDataModule:
             "A value less than 1 means to disable time warp.",
         )
-        group.add_argument(
-            "--enable-musan",
-            type=str2bool,
-            default=True,
-            help="When enabled, select noise from MUSAN and mix it"
-            "with training dataset. ",
", - ) - def train_dataloaders(self, cuts_train: CutSet) -> DataLoader: logging.info("About to get Musan cuts") cuts_musan = load_manifest( - self.args.manifest_dir / "cuts_musan.json.gz" + self.args.manifest_dir / "musan_cuts.jsonl.gz" ) - transforms = [] - if self.args.enable_musan: - logging.info("Enable MUSAN") - transforms.append( + input_strategy = PrecomputedFeatures() + if self.args.on_the_fly_feats: + input_strategy = OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80, sampling_rate=16000)), + ) + + train = K2SpeechRecognitionDataset( + input_strategy=input_strategy, + cut_transforms=[ + PerturbSpeed(factors=[0.9, 1.1], p=2 / 3, preserve_id=True), + Resample16kHz(), CutMix( cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) - ) - else: - logging.info("Disable MUSAN") - - if self.args.concatenate_cuts: - logging.info( - f"Using cut concatenation with duration factor " - f"{self.args.duration_factor} and gap {self.args.gap}." - ) - # Cut concatenation should be the first transform in the list, - # so that if we e.g. mix noise in, it will fill the gaps between - # different utterances. - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - input_transforms = [] - if self.args.enable_spec_aug: - logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) - input_transforms.append( + ), + ], + input_transforms=[ SpecAugment( time_warp_factor=self.args.spec_aug_time_warp_factor, num_frame_masks=2, @@ -224,56 +156,19 @@ class LibriSpeechAsrDataModule: num_feature_masks=2, frames_mask_size=100, ) - ) - else: - logging.info("Disable SpecAugment") - - logging.info("About to create train dataset") - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, + ], + return_cuts=True, ) - if self.args.on_the_fly_feats: - # NOTE: the PerturbSpeed transform should be added only if we - # remove it from data prep stage. - # Add on-the-fly speed perturbation; since originally it would - # have increased epoch size by 3, we will apply prob 2/3 and use - # 3x more epochs. - # Speed perturbation probably should come first before - # concatenation, but in principle the transforms order doesn't have - # to be strict (e.g. could be randomized) - # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa - # Drop feats to be on the safe side. 
-            train = K2SpeechRecognitionDataset(
-                cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
-                input_transforms=input_transforms,
-                return_cuts=self.args.return_cuts,
-            )
+        train_sampler = DynamicBucketingSampler(
+            cuts_train,
+            max_duration=self.args.max_duration,
+            shuffle=self.args.shuffle,
+            num_buckets=self.args.num_buckets,
+            drop_last=True,
+        )
 
-        if self.args.bucketing_sampler:
-            logging.info("Using BucketingSampler.")
-            train_sampler = BucketingSampler(
-                cuts_train,
-                max_duration=self.args.max_duration,
-                shuffle=self.args.shuffle,
-                num_buckets=self.args.num_buckets,
-                bucket_method="equal_duration",
-                drop_last=True,
-            )
-        else:
-            logging.info("Using SingleCutSampler.")
-            train_sampler = SingleCutSampler(
-                cuts_train,
-                max_duration=self.args.max_duration,
-                shuffle=self.args.shuffle,
-            )
         logging.info("About to create train dataloader")
-
         train_dl = DataLoader(
             train,
             sampler=train_sampler,
@@ -285,39 +180,34 @@ class LibriSpeechAsrDataModule:
         return train_dl
 
     def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
-        transforms = []
-        if self.args.concatenate_cuts:
-            transforms = [
-                CutConcatenate(
-                    duration_factor=self.args.duration_factor, gap=self.args.gap
-                )
-            ] + transforms
 
         logging.info("About to create dev dataset")
+        input_strategy = PrecomputedFeatures()
         if self.args.on_the_fly_feats:
-            validate = K2SpeechRecognitionDataset(
-                cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                ),
-                return_cuts=self.args.return_cuts,
-            )
-        else:
-            validate = K2SpeechRecognitionDataset(
-                cut_transforms=transforms,
-                return_cuts=self.args.return_cuts,
+            input_strategy = OnTheFlyFeatures(
+                Fbank(FbankConfig(num_mel_bins=80, sampling_rate=16000)),
             )
+
+        validate = K2SpeechRecognitionDataset(
+            return_cuts=True,
+            input_strategy=input_strategy,
+            cut_transforms=[
+                Resample16kHz(),
+            ],
+        )
+
         valid_sampler = BucketingSampler(
             cuts_valid,
             max_duration=self.args.max_duration,
             shuffle=False,
         )
+
         logging.info("About to create dev dataloader")
         valid_dl = DataLoader(
             validate,
             sampler=valid_sampler,
             batch_size=None,
-            num_workers=2,
+            num_workers=self.args.num_workers,
             persistent_workers=False,
         )
 
@@ -325,11 +215,19 @@ class LibriSpeechAsrDataModule:
     def test_dataloaders(self, cuts: CutSet) -> DataLoader:
         logging.debug("About to create test dataset")
+
+        input_strategy = PrecomputedFeatures()
+        if self.args.on_the_fly_feats:
+            input_strategy = OnTheFlyFeatures(
+                Fbank(FbankConfig(num_mel_bins=80, sampling_rate=16000)),
+            )
+
         test = K2SpeechRecognitionDataset(
-            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
-            if self.args.on_the_fly_feats
-            else PrecomputedFeatures(),
-            return_cuts=self.args.return_cuts,
+            return_cuts=True,
+            input_strategy=input_strategy,
+            cut_transforms=[
+                Resample16kHz(),
+            ],
         )
         sampler = BucketingSampler(
             cuts, max_duration=self.args.max_duration, shuffle=False
@@ -344,42 +242,44 @@ class LibriSpeechAsrDataModule:
         return test_dl
 
     @lru_cache()
-    def train_clean_100_cuts(self) -> CutSet:
-        logging.info("About to get train-clean-100 cuts")
-        return load_manifest(
-            self.args.manifest_dir / "cuts_train-clean-100.json.gz"
+    def train_cuts(self) -> CutSet:
+        logging.info("About to get train Fisher + SWBD cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir
+            / "train_utterances_fisher-swbd_cuts.jsonl.gz"
         )
 
     @lru_cache()
-    def train_clean_360_cuts(self) -> CutSet:
-        logging.info("About to get train-clean-360 cuts")
-        return load_manifest(
self.args.manifest_dir / "cuts_train-clean-360.json.gz" + def dev_cuts(self) -> CutSet: + logging.info("About to get dev Fisher + SWBD cuts") + return load_manifest_lazy( + self.args.manifest_dir / "dev_utterances_fisher-swbd_cuts.jsonl.gz" ) @lru_cache() - def train_other_500_cuts(self) -> CutSet: - logging.info("About to get train-other-500 cuts") - return load_manifest( - self.args.manifest_dir / "cuts_train-other-500.json.gz" - ) - - @lru_cache() - def dev_clean_cuts(self) -> CutSet: - logging.info("About to get dev-clean cuts") - return load_manifest(self.args.manifest_dir / "cuts_dev-clean.json.gz") - - @lru_cache() - def dev_other_cuts(self) -> CutSet: - logging.info("About to get dev-other cuts") - return load_manifest(self.args.manifest_dir / "cuts_dev-other.json.gz") - - @lru_cache() - def test_clean_cuts(self) -> CutSet: + def test_cuts(self) -> CutSet: logging.info("About to get test-clean cuts") - return load_manifest(self.args.manifest_dir / "cuts_test-clean.json.gz") + raise NotImplemented - @lru_cache() - def test_other_cuts(self) -> CutSet: - logging.info("About to get test-other cuts") - return load_manifest(self.args.manifest_dir / "cuts_test-other.json.gz") + +def test(): + parser = argparse.ArgumentParser() + AsrDataModule.add_arguments(parser) + args = parser.parse_args() + adm = AsrDataModule(args) + + cuts = adm.train_cuts() + dl = adm.train_dataloaders(cuts) + for i, batch in tqdm(enumerate(dl)): + if i == 100: + break + + cuts = adm.dev_cuts() + dl = adm.valid_dataloaders(cuts) + for i, batch in tqdm(enumerate(dl)): + if i == 100: + break + + +if __name__ == '__main__': + test()