From 91488ce972689822e813b7a1f916d8a86a242e72 Mon Sep 17 00:00:00 2001 From: yfyeung Date: Thu, 4 Jul 2024 15:53:15 +0800 Subject: [PATCH] suport for cascades --- .../ASR/tdnn_lstm_ctc/asr_datamodule.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index 1b52aa8b5..6d635361a 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -219,6 +219,8 @@ class LibriSpeechAsrDataModule: self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None, + world_size: Optional[int] = None, + rank: Optional[int] = None, ) -> DataLoader: """ Args: @@ -314,6 +316,8 @@ class LibriSpeechAsrDataModule: buffer_size=self.args.num_buckets * 2000, shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=self.args.drop_last, + world_size=world_size, + rank=rank, ) else: logging.info("Using SimpleCutSampler.") @@ -321,6 +325,8 @@ class LibriSpeechAsrDataModule: cuts_train, max_duration=self.args.max_duration, shuffle=self.args.shuffle, + world_size=world_size, + rank=rank, ) logging.info("About to create train dataloader") @@ -344,7 +350,12 @@ class LibriSpeechAsrDataModule: return train_dl - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + def valid_dataloaders( + self, + cuts_valid: CutSet, + world_size: Optional[int] = None, + rank: Optional[int] = None, + ) -> DataLoader: transforms = [] if self.args.concatenate_cuts: transforms = [ @@ -369,6 +380,8 @@ class LibriSpeechAsrDataModule: cuts_valid, max_duration=self.args.max_duration, shuffle=False, + world_size=world_size, + rank=rank, ) logging.info("About to create dev dataloader") valid_dl = DataLoader(