diff --git a/egs/gigaspeech/ASR/zipformer/asr_datamodule.py b/egs/gigaspeech/ASR/zipformer/asr_datamodule.py index 7efb2b0d0..c4472ed23 100644 --- a/egs/gigaspeech/ASR/zipformer/asr_datamodule.py +++ b/egs/gigaspeech/ASR/zipformer/asr_datamodule.py @@ -216,7 +216,9 @@ class GigaSpeechAsrDataModule: ) def train_dataloaders( - self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None, + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, ) -> DataLoader: """ Args: @@ -358,10 +360,13 @@ class GigaSpeechAsrDataModule: ) else: validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, return_cuts=self.args.return_cuts, + cut_transforms=transforms, + return_cuts=self.args.return_cuts, ) valid_sampler = DynamicBucketingSampler( - cuts_valid, max_duration=self.args.max_duration, shuffle=False, + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, ) logging.info("About to create dev dataloader") valid_dl = DataLoader( @@ -383,11 +388,16 @@ class GigaSpeechAsrDataModule: return_cuts=self.args.return_cuts, ) sampler = DynamicBucketingSampler( - cuts, max_duration=self.args.max_duration, shuffle=False, + cuts, + max_duration=self.args.max_duration, + shuffle=False, ) logging.debug("About to create test dataloader") test_dl = DataLoader( - test, batch_size=None, sampler=sampler, num_workers=self.args.num_workers, + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, ) return test_dl