suport for cascades

This commit is contained in:
yfyeung 2024-07-04 15:53:15 +08:00
parent 7235b8561b
commit 91488ce972

View File

@ -219,6 +219,8 @@ class LibriSpeechAsrDataModule:
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
world_size: Optional[int] = None,
rank: Optional[int] = None,
) -> DataLoader:
"""
Args:
@ -314,6 +316,8 @@ class LibriSpeechAsrDataModule:
buffer_size=self.args.num_buckets * 2000,
shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
world_size=world_size,
rank=rank,
)
else:
logging.info("Using SimpleCutSampler.")
@ -321,6 +325,8 @@ class LibriSpeechAsrDataModule:
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
world_size=world_size,
rank=rank,
)
logging.info("About to create train dataloader")
@ -344,7 +350,12 @@ class LibriSpeechAsrDataModule:
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
def valid_dataloaders(
self,
cuts_valid: CutSet,
world_size: Optional[int] = None,
rank: Optional[int] = None,
) -> DataLoader:
transforms = []
if self.args.concatenate_cuts:
transforms = [
@ -369,6 +380,8 @@ class LibriSpeechAsrDataModule:
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
world_size=world_size,
rank=rank,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(