Fix for black

This commit is contained in:
Yifan Yang 2023-10-17 20:07:32 +08:00
parent 6eddab2a8d
commit e71d0086cb

View File

@ -216,7 +216,9 @@ class GigaSpeechAsrDataModule:
) )
def train_dataloaders( def train_dataloaders(
self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None, self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader: ) -> DataLoader:
""" """
Args: Args:
@ -358,10 +360,13 @@ class GigaSpeechAsrDataModule:
) )
else: else:
validate = K2SpeechRecognitionDataset( validate = K2SpeechRecognitionDataset(
cut_transforms=transforms, return_cuts=self.args.return_cuts, cut_transforms=transforms,
return_cuts=self.args.return_cuts,
) )
valid_sampler = DynamicBucketingSampler( valid_sampler = DynamicBucketingSampler(
cuts_valid, max_duration=self.args.max_duration, shuffle=False, cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
) )
logging.info("About to create dev dataloader") logging.info("About to create dev dataloader")
valid_dl = DataLoader( valid_dl = DataLoader(
@ -383,11 +388,16 @@ class GigaSpeechAsrDataModule:
return_cuts=self.args.return_cuts, return_cuts=self.args.return_cuts,
) )
sampler = DynamicBucketingSampler( sampler = DynamicBucketingSampler(
cuts, max_duration=self.args.max_duration, shuffle=False, cuts,
max_duration=self.args.max_duration,
shuffle=False,
) )
logging.debug("About to create test dataloader") logging.debug("About to create test dataloader")
test_dl = DataLoader( test_dl = DataLoader(
test, batch_size=None, sampler=sampler, num_workers=self.args.num_workers, test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
) )
return test_dl return test_dl