Fix for black

This commit is contained in:
Yifan Yang 2023-10-17 20:07:32 +08:00
parent 6eddab2a8d
commit e71d0086cb

View File

@ -216,7 +216,9 @@ class GigaSpeechAsrDataModule:
)
def train_dataloaders(
self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None,
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
"""
Args:
@ -358,10 +360,13 @@ class GigaSpeechAsrDataModule:
)
else:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms, return_cuts=self.args.return_cuts,
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid, max_duration=self.args.max_duration, shuffle=False,
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
@ -383,11 +388,16 @@ class GigaSpeechAsrDataModule:
return_cuts=self.args.return_cuts,
)
sampler = DynamicBucketingSampler(
cuts, max_duration=self.args.max_duration, shuffle=False,
cuts,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test, batch_size=None, sampler=sampler, num_workers=self.args.num_workers,
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
return test_dl