diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
index c47964b07..90f407a31 100644
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -21,7 +21,7 @@ import inspect
 import logging
 from functools import lru_cache
 from pathlib import Path
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional, Union
 
 import torch
 from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
@@ -33,6 +33,7 @@ from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
     PrecomputedFeatures,
     SingleCutSampler,
     SpecAugment,
+    StatelessSampler,
 )
 from lhotse.dataset.input_strategies import (  # noqa F401 For AudioSamples
     AudioSamples,
@@ -123,6 +124,12 @@ class LibriSpeechAsrDataModule:
             help="The number of buckets for the DynamicBucketingSampler"
             "(you might want to increase it for larger datasets).",
         )
+        group.add_argument(
+            "--stateless-sampler",
+            type=str2bool,
+            default=False,
+            help="When enabled, the dataloader is completely stateless.",
+        )
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
@@ -217,7 +224,7 @@ class LibriSpeechAsrDataModule:
 
     def train_dataloaders(
         self,
-        cuts_train: CutSet,
+        cuts_train: Union[CutSet, List[str]],
         sampler_state_dict: Optional[Dict[str, Any]] = None,
     ) -> DataLoader:
         """
@@ -305,21 +312,40 @@ class LibriSpeechAsrDataModule:
         )
 
         if self.args.bucketing_sampler:
-            logging.info("Using DynamicBucketingSampler.")
-            train_sampler = DynamicBucketingSampler(
-                cuts_train,
-                max_duration=self.args.max_duration,
-                shuffle=self.args.shuffle,
-                num_buckets=self.args.num_buckets,
-                drop_last=self.args.drop_last,
-            )
+            if self.args.stateless_sampler:
+                logging.info("Using bucketing StatelessSampler.")
+                train_sampler = StatelessSampler(
+                    cuts_train,
+                    base_seed=0,
+                    index_path=self.args.manifest_dir / "files.idx",
+                    max_duration=self.args.max_duration,
+                    num_buckets=self.args.num_buckets,
+                )
+            else:
+                logging.info("Using DynamicBucketingSampler.")
+                train_sampler = DynamicBucketingSampler(
+                    cuts_train,
+                    max_duration=self.args.max_duration,
+                    shuffle=self.args.shuffle,
+                    num_buckets=self.args.num_buckets,
+                    drop_last=self.args.drop_last,
+                )
         else:
-            logging.info("Using SingleCutSampler.")
-            train_sampler = SingleCutSampler(
-                cuts_train,
-                max_duration=self.args.max_duration,
-                shuffle=self.args.shuffle,
-            )
+            if self.args.stateless_sampler:
+                logging.info("Using non-bucketing StatelessSampler.")
+                train_sampler = StatelessSampler(
+                    cuts_train,
+                    base_seed=0,
+                    index_path=self.args.manifest_dir / "files.idx",
+                    max_duration=self.args.max_duration,
+                )
+            else:
+                logging.info("Using SingleCutSampler.")
+                train_sampler = SingleCutSampler(
+                    cuts_train,
+                    max_duration=self.args.max_duration,
+                    shuffle=self.args.shuffle,
+                )
         logging.info("About to create train dataloader")
 
         if sampler_state_dict is not None:
@@ -379,7 +405,7 @@ class LibriSpeechAsrDataModule:
 
         return valid_dl
 
-    def test_dataloaders(self, cuts: CutSet) -> DataLoader:
+    def test_dataloaders(self, cuts_test: CutSet) -> DataLoader:
         logging.debug("About to create test dataset")
         test = K2SpeechRecognitionDataset(
             input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
@@ -388,7 +414,7 @@ class LibriSpeechAsrDataModule:
             return_cuts=self.args.return_cuts,
         )
         sampler = DynamicBucketingSampler(
-            cuts,
+            cuts_test,
             max_duration=self.args.max_duration,
             shuffle=False,
         )
@@ -402,42 +428,57 @@ class LibriSpeechAsrDataModule:
         return test_dl
 
     @lru_cache()
-    def train_clean_5_cuts(self) -> CutSet:
-        logging.info("mini_librispeech: About to get train-clean-5 cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "librispeech_cuts_train-clean-5.jsonl.gz"
-        )
+    def train_clean_5_cuts(self) -> Union[CutSet, List[str]]:
+        if self.args.stateless_sampler:
+            return [self.args.manifest_dir / "librispeech_cuts_train-clean-5.jsonl"]
+        else:
+            logging.info("mini_librispeech: About to get train-clean-5 cuts")
+            return load_manifest_lazy(
+                self.args.manifest_dir / "librispeech_cuts_train-clean-5.jsonl.gz"
+            )
 
     @lru_cache()
-    def train_clean_100_cuts(self) -> CutSet:
-        logging.info("About to get train-clean-100 cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz"
-        )
+    def train_clean_100_cuts(self) -> Union[CutSet, List[str]]:
+        if self.args.stateless_sampler:
+            return [self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl"]
+        else:
+            logging.info("About to get train-clean-100 cuts")
+            return load_manifest_lazy(
+                self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz"
+            )
 
     @lru_cache()
-    def train_clean_360_cuts(self) -> CutSet:
-        logging.info("About to get train-clean-360 cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz"
-        )
+    def train_clean_360_cuts(self) -> Union[CutSet, List[str]]:
+        if self.args.stateless_sampler:
+            return [self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl"]
+        else:
+            logging.info("About to get train-clean-360 cuts")
+            return load_manifest_lazy(
+                self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz"
+            )
 
     @lru_cache()
-    def train_other_500_cuts(self) -> CutSet:
-        logging.info("About to get train-other-500 cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz"
-        )
+    def train_other_500_cuts(self) -> Union[CutSet, List[str]]:
+        if self.args.stateless_sampler:
+            return [self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl"]
+        else:
+            logging.info("About to get train-other-500 cuts")
+            return load_manifest_lazy(
+                self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz"
+            )
 
     @lru_cache()
-    def train_all_shuf_cuts(self) -> CutSet:
-        logging.info(
-            "About to get the shuffled train-clean-100, \
-            train-clean-360 and train-other-500 cuts"
-        )
-        return load_manifest_lazy(
-            self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz"
-        )
+    def train_all_shuf_cuts(self) -> Union[CutSet, List[str]]:
+        if self.args.stateless_sampler:
+            return [self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl"]
+        else:
+            logging.info(
+                "About to get the shuffled train-clean-100, \
+                train-clean-360 and train-other-500 cuts"
+            )
+            return load_manifest_lazy(
+                self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz"
+            )
 
     @lru_cache()
     def dev_clean_2_cuts(self) -> CutSet:
diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py
index bc3e9c1ba..d5918be44 100755
--- a/egs/librispeech/ASR/zipformer/train.py
+++ b/egs/librispeech/ASR/zipformer/train.py
@@ -604,11 +604,11 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
 
 
 def get_model(params: AttributeDict) -> nn.Module:
-    assert (
-        params.use_transducer or params.use_ctc
-    ), (f"At least one of them should be True, "
+    assert params.use_transducer or params.use_ctc, (
+        f"At least one of them should be True, "
         f"but got params.use_transducer={params.use_transducer}, "
-        f"params.use_ctc={params.use_ctc}")
+        f"params.use_ctc={params.use_ctc}"
+    )
 
     encoder_embed = get_encoder_embed(params)
     encoder = get_encoder_model(params)
@@ -808,17 +808,16 @@ def compute_loss(
             # take down the scale on the simple loss from 1.0 at the start
             # to params.simple_loss scale by warm_step.
             simple_loss_scale = (
-                s if batch_idx_train >= warm_step
+                s
+                if batch_idx_train >= warm_step
                 else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
             )
             pruned_loss_scale = (
-                1.0 if batch_idx_train >= warm_step
+                1.0
+                if batch_idx_train >= warm_step
                 else 0.1 + 0.9 * (batch_idx_train / warm_step)
             )
-            loss += (
-                simple_loss_scale * simple_loss
-                + pruned_loss_scale * pruned_loss
-            )
+            loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
 
         if params.use_ctc:
             loss += params.ctc_loss_scale * ctc_loss
@@ -1217,7 +1216,8 @@ def run(rank, world_size, args):
 
         return True
 
-    train_cuts = train_cuts.filter(remove_short_and_long_utt)
+    if not params.stateless_sampler:
+        train_cuts = train_cuts.filter(remove_short_and_long_utt)
 
     if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
         # We only load the sampler's state dict when it loads a checkpoint
@@ -1234,7 +1234,7 @@ def run(rank, world_size, args):
     valid_cuts += librispeech.dev_other_cuts()
     valid_dl = librispeech.valid_dataloaders(valid_cuts)
 
-    if not params.print_diagnostics:
+    if 0 and not params.print_diagnostics:
         scan_pessimistic_batches_for_oom(
             model=model,
             train_dl=train_dl,
@@ -1251,7 +1251,8 @@ def run(rank, world_size, args):
     for epoch in range(params.start_epoch, params.num_epochs + 1):
         scheduler.step_epoch(epoch - 1)
         fix_random_seed(params.seed + epoch - 1)
-        train_dl.sampler.set_epoch(epoch - 1)
+        if not params.stateless_sampler:
+            train_dl.sampler.set_epoch(epoch - 1)
 
         if tb_writer is not None:
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)