From 924aa0b3bff0a31868c3cdd686f771c244b3e866 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Mon, 21 Aug 2023 14:46:44 -0400 Subject: [PATCH] [Do not merge] example of using LibriSpeech + Lhotse Shar --- egs/librispeech/ASR/prepare_shar.sh | 101 ++++ .../ASR/tdnn_lstm_ctc/asr_datamodule.py | 528 +++++++++++++----- egs/librispeech/ASR/tdnn_lstm_ctc/train.py | 25 +- 3 files changed, 511 insertions(+), 143 deletions(-) create mode 100755 egs/librispeech/ASR/prepare_shar.sh diff --git a/egs/librispeech/ASR/prepare_shar.sh b/egs/librispeech/ASR/prepare_shar.sh new file mode 100755 index 000000000..30145fcb1 --- /dev/null +++ b/egs/librispeech/ASR/prepare_shar.sh @@ -0,0 +1,101 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail +set -x + +nj=15 +stage=-1 +stop_stage=100 + +# We assume dl_dir (download dir) contains the following +# directories and files. If not, they will be downloaded +# by this script automatically. +# +# - $dl_dir/LibriSpeech +# You can find BOOKS.TXT, test-clean, train-clean-360, etc, inside it. +# You can download them from https://www.openslr.org/12 +# +# - $dl_dir/lm +# This directory contains the following files downloaded from +# http://www.openslr.org/resources/11 +# +# - 3-gram.pruned.1e-7.arpa.gz +# - 3-gram.pruned.1e-7.arpa +# - 4-gram.arpa.gz +# - 4-gram.arpa +# - librispeech-vocab.txt +# - librispeech-lexicon.txt +# - librispeech-lm-norm.txt.gz +# +# - $dl_dir/musan +# This directory contains the following directories downloaded from +# http://www.openslr.org/17/ +# +# - music +# - noise +# - speech +dl_dir=$PWD/download + +. 
shared/parse_options.sh || exit 1 + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +# Run data downloading and core manifest preparation +./prepare.sh --nj $nj --stage $stage --stop-stage 3 + +# Split the data into shards and compute the features on shard level +# This step leverages Lhotse Shar format for optimized sequential I/O +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: [Shar] Split manifests into shards and compute fbank features" + mkdir -p data/shar + if [ ! -e data/shar/.librispeech.done ]; then + for part in dev-clean dev-other test-clean test-other train-clean-100 train-clean-360 train-other-500; do + lhotse cut simple \ + -r data/manifests/librispeech_recordings_${part}.jsonl.gz \ + -s data/manifests/librispeech_supervisions_${part}.jsonl.gz \ + data/manifests/librispeech_cuts_${part}.jsonl.gz + done + + lhotse combine \ + data/manifests/librispeech_cuts_train-{clean-100,clean-360,other-500}.jsonl.gz - \ + | shuf \ + | gzip -c \ + > data/manifests/librispeech_cuts_train_all.jsonl.gz + + lhotse shar export -j$nj -v -a flac -s 1000 \ + data/manifests/librispeech_cuts_train_all.jsonl.gz \ + data/shar + + lhotse shar compute-features -v -j$nj data/shar + + touch data/shar/.librispeech.done + fi + + if [ ! 
-e data/fbank/.librispeech-validated.done ]; then + log "Validating data/fbank for LibriSpeech" + parts=( + train-clean-100 + train-clean-360 + train-other-500 + test-clean + test-other + dev-clean + dev-other + ) + for part in ${parts[@]}; do + python3 ./local/validate_manifest.py \ + data/fbank/librispeech_cuts_${part}.jsonl.gz + done + touch data/fbank/.librispeech-validated.done + fi +fi + +# Run the rest of data preparation steps +./prepare.sh --stage $stage --stop-stage $stop_stage diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index c47964b07..53db0d772 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -29,10 +29,12 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures CutConcatenate, CutMix, DynamicBucketingSampler, + IterableDatasetWrapper, K2SpeechRecognitionDataset, PrecomputedFeatures, SingleCutSampler, SpecAugment, + make_worker_init_fn, ) from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples AudioSamples, @@ -52,6 +54,155 @@ class _SeedWorkers: fix_random_seed(self.seed + worker_id) +def add_dataloading_arguments(parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--full-libri", + type=str2bool, + default=True, + help="""Used only when --mini-libri is False.When enabled, + use 960h LibriSpeech. 
Otherwise, use 100h subset.""", + ) + group.add_argument( + "--mini-libri", + type=str2bool, + default=False, + help="True for mini librispeech", + ) + + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=float, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler " + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. 
Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it " + "with training dataset. ", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples (returns audio + audio lens), or " + "OnTheFlyFeatures/PrecomputedFeatures (both return features + feature lens)", + ) + + group.add_argument( + "--shar-dir", + type=Path, + default=Path("data/shar"), + help="Path to directory with data in Lhotse Shar format (if used)", + ) + + class LibriSpeechAsrDataModule: """ DataModule for k2 ASR experiments. 
@@ -75,145 +226,7 @@ class LibriSpeechAsrDataModule: @classmethod def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - group.add_argument( - "--full-libri", - type=str2bool, - default=True, - help="""Used only when --mini-libri is False.When enabled, - use 960h LibriSpeech. Otherwise, use 100h subset.""", - ) - group.add_argument( - "--mini-libri", - type=str2bool, - default=False, - help="True for mini librispeech", - ) - - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/fbank"), - help="Path to directory with train/valid/test cuts.", - ) - group.add_argument( - "--max-duration", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - group.add_argument( - "--concatenate-cuts", - type=str2bool, - default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", - ) - group.add_argument( - "--duration-factor", - type=float, - default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", - ) - group.add_argument( - "--gap", - type=float, - default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. 
This padding is filled with noise when " - "noise augmentation is used.", - ) - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--drop-last", - type=str2bool, - default=True, - help="Whether to drop last batch. Used by sampler.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", - ) - - group.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - group.add_argument( - "--enable-spec-aug", - type=str2bool, - default=True, - help="When enabled, use SpecAugment for training dataset.", - ) - - group.add_argument( - "--spec-aug-time-warp-factor", - type=int, - default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", - ) - - group.add_argument( - "--enable-musan", - type=str2bool, - default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. 
", - ) - - group.add_argument( - "--input-strategy", - type=str, - default="PrecomputedFeatures", - help="AudioSamples or PrecomputedFeatures", - ) + return add_dataloading_arguments(parser) def train_dataloaders( self, @@ -473,3 +486,240 @@ class LibriSpeechAsrDataModule: return load_manifest_lazy( self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz" ) + + +class LibriSpeechSharAsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + if sampler_state_dict is not None: + logging.warning( + "Loading sampler state dict is not supported for Lhotse Shar -- ignoring this." + ) + + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + transforms.append( + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. 
+ # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + logging.info(f"Num frame mask: {num_frame_masks}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + if self.args.input_strategy == "OnTheFlyFeatures": + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. 
+ train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + ) + else: + train = K2SpeechRecognitionDataset( + input_strategy=eval(self.args.input_strategy)(), + cut_transforms=transforms, + input_transforms=input_transforms, + ) + + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train.repeat(), # sample infinite CutSet + max_duration=self.args.max_duration, + shuffle=True, + num_buckets=self.args.num_buckets, + # DDP auto-detection is disabled for Lhotse Shar + # instead, each worker process will initialize sampling + # with a different random seed using worker_init_fn, + # and CutSet.from_shar is going to react to this change. + rank=0, + world_size=1, + ) + logging.info("About to create train dataloader") + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + rank, world_size = None, None + if torch.distributed.is_initialized(): + rank, world_size = ( + torch.distributed.get_rank(), + torch.distributed.get_world_size(), + ) + + train_dl = DataLoader( + IterableDatasetWrapper(dataset=train, sampler=train_sampler), + num_workers=self.args.num_workers, + batch_size=None, + worker_init_fn=make_worker_init_fn( + rank=rank, world_size=world_size, seed=seed + ), + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + logging.info("About to create dev dataset") + if self.args.input_strategy == "OnTheFlyFeatures": + validate = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + ) + else: + validate = K2SpeechRecognitionDataset() + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( 
+ validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.input_strategy == "OnTheFlyFeatures" + else eval(self.args.input_strategy)(), + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_all_shuf_cuts(self) -> CutSet: + logging.info( + "About to get the shuffled and sharded cuts for full (960h) LibriSpeech using Lhotse Shar" + ) + # Below we'll figure out which files to read. + # Since we only use either (precomputed) features or recordings, + # we shouldn't iterate over both at the same time. + shar_dir = Path(self.args.shar_dir) + fields = {"cuts": sorted(shar_dir.glob("cuts.*.jsonl*"))} + if self.args.input_strategy == "PrecomputedFeatures": + logging.info( + "Requested PrecomputedFeatures, we'll only read features.XXXXXX.tar files." + ) + fields["features"] = sorted(shar_dir.glob("features.*.tar")) + else: # AudioSamples / OnTheFlyFeatures + logging.info( + f"Requested {self.args.input_strategy}, we'll only read recording.XXXXXX.tar files." + ) + fields["recording"] = sorted(shar_dir.glob("recording.*.tar")) + return CutSet.from_shar(fields=fields, shuffle_shards=True, seed="randomized") + + @lru_cache() + def train_clean_100_cuts(self) -> CutSet: + raise NotImplementedError( + "LibriSpeech 100h subset support for Lhotse Shar is not implemented." 
+ ) + + @lru_cache() + def dev_clean_cuts(self) -> CutSet: + logging.info("About to get dev-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz" + ) + + @lru_cache() + def dev_other_cuts(self) -> CutSet: + logging.info("About to get dev-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz" + ) + + @lru_cache() + def test_clean_cuts(self) -> CutSet: + logging.info("About to get test-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz" + ) + + @lru_cache() + def test_other_cuts(self) -> CutSet: + logging.info("About to get test-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz" + ) diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py index 0aa1587ba..4a5fe4178 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py @@ -37,7 +37,11 @@ import torch import torch.multiprocessing as mp import torch.nn as nn import torch.optim as optim -from asr_datamodule import LibriSpeechAsrDataModule +from asr_datamodule import ( + LibriSpeechAsrDataModule, + LibriSpeechSharAsrDataModule, + add_dataloading_arguments, +) from lhotse.cut import Cut from lhotse.utils import fix_random_seed from model import TdnnLstm @@ -112,6 +116,13 @@ def get_parser(): help="The seed for random generators intended for reproducibility", ) + parser.add_argument( + "--use-shar", + type=str2bool, + default=False, + help="Use Lhotse Shar data format for faster, sequential I/O. 
Requires running ./prepare_shar.sh first.", + ) + return parser @@ -555,7 +566,10 @@ def run(rank, world_size, args): optimizer.load_state_dict(checkpoints["optimizer"]) scheduler.load_state_dict(checkpoints["scheduler"]) - librispeech = LibriSpeechAsrDataModule(args) + if args.use_shar: + librispeech = LibriSpeechSharAsrDataModule(args) + else: + librispeech = LibriSpeechAsrDataModule(args) if params.full_libri: train_cuts = librispeech.train_all_shuf_cuts() @@ -584,7 +598,10 @@ def run(rank, world_size, args): for epoch in range(params.start_epoch, params.num_epochs): fix_random_seed(params.seed + epoch) - train_dl.sampler.set_epoch(epoch) + try: + train_dl.sampler.set_epoch(epoch) + except AttributeError: + pass # with Lhotse Shar the sampler won't have a set_epoch attribute if epoch > params.start_epoch: logging.info(f"epoch {epoch}, lr: {scheduler.get_last_lr()[0]}") @@ -628,7 +645,7 @@ def run(rank, world_size, args): def main(): parser = get_parser() - LibriSpeechAsrDataModule.add_arguments(parser) + add_dataloading_arguments(parser) args = parser.parse_args() world_size = args.world_size