mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-15 20:22:42 +00:00
suport for cascades
This commit is contained in:
parent
7235b8561b
commit
91488ce972
@ -219,6 +219,8 @@ class LibriSpeechAsrDataModule:
|
||||
self,
|
||||
cuts_train: CutSet,
|
||||
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||
world_size: Optional[int] = None,
|
||||
rank: Optional[int] = None,
|
||||
) -> DataLoader:
|
||||
"""
|
||||
Args:
|
||||
@ -314,6 +316,8 @@ class LibriSpeechAsrDataModule:
|
||||
buffer_size=self.args.num_buckets * 2000,
|
||||
shuffle_buffer_size=self.args.num_buckets * 5000,
|
||||
drop_last=self.args.drop_last,
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
)
|
||||
else:
|
||||
logging.info("Using SimpleCutSampler.")
|
||||
@ -321,6 +325,8 @@ class LibriSpeechAsrDataModule:
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
)
|
||||
logging.info("About to create train dataloader")
|
||||
|
||||
@ -344,7 +350,12 @@ class LibriSpeechAsrDataModule:
|
||||
|
||||
return train_dl
|
||||
|
||||
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
|
||||
def valid_dataloaders(
|
||||
self,
|
||||
cuts_valid: CutSet,
|
||||
world_size: Optional[int] = None,
|
||||
rank: Optional[int] = None,
|
||||
) -> DataLoader:
|
||||
transforms = []
|
||||
if self.args.concatenate_cuts:
|
||||
transforms = [
|
||||
@ -369,6 +380,8 @@ class LibriSpeechAsrDataModule:
|
||||
cuts_valid,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=False,
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
)
|
||||
logging.info("About to create dev dataloader")
|
||||
valid_dl = DataLoader(
|
||||
|
Loading…
x
Reference in New Issue
Block a user