mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-16 12:42:20 +00:00
suport for cascades
This commit is contained in:
parent
7235b8561b
commit
91488ce972
@ -219,6 +219,8 @@ class LibriSpeechAsrDataModule:
|
|||||||
self,
|
self,
|
||||||
cuts_train: CutSet,
|
cuts_train: CutSet,
|
||||||
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||||
|
world_size: Optional[int] = None,
|
||||||
|
rank: Optional[int] = None,
|
||||||
) -> DataLoader:
|
) -> DataLoader:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -314,6 +316,8 @@ class LibriSpeechAsrDataModule:
|
|||||||
buffer_size=self.args.num_buckets * 2000,
|
buffer_size=self.args.num_buckets * 2000,
|
||||||
shuffle_buffer_size=self.args.num_buckets * 5000,
|
shuffle_buffer_size=self.args.num_buckets * 5000,
|
||||||
drop_last=self.args.drop_last,
|
drop_last=self.args.drop_last,
|
||||||
|
world_size=world_size,
|
||||||
|
rank=rank,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logging.info("Using SimpleCutSampler.")
|
logging.info("Using SimpleCutSampler.")
|
||||||
@ -321,6 +325,8 @@ class LibriSpeechAsrDataModule:
|
|||||||
cuts_train,
|
cuts_train,
|
||||||
max_duration=self.args.max_duration,
|
max_duration=self.args.max_duration,
|
||||||
shuffle=self.args.shuffle,
|
shuffle=self.args.shuffle,
|
||||||
|
world_size=world_size,
|
||||||
|
rank=rank,
|
||||||
)
|
)
|
||||||
logging.info("About to create train dataloader")
|
logging.info("About to create train dataloader")
|
||||||
|
|
||||||
@ -344,7 +350,12 @@ class LibriSpeechAsrDataModule:
|
|||||||
|
|
||||||
return train_dl
|
return train_dl
|
||||||
|
|
||||||
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
|
def valid_dataloaders(
|
||||||
|
self,
|
||||||
|
cuts_valid: CutSet,
|
||||||
|
world_size: Optional[int] = None,
|
||||||
|
rank: Optional[int] = None,
|
||||||
|
) -> DataLoader:
|
||||||
transforms = []
|
transforms = []
|
||||||
if self.args.concatenate_cuts:
|
if self.args.concatenate_cuts:
|
||||||
transforms = [
|
transforms = [
|
||||||
@ -369,6 +380,8 @@ class LibriSpeechAsrDataModule:
|
|||||||
cuts_valid,
|
cuts_valid,
|
||||||
max_duration=self.args.max_duration,
|
max_duration=self.args.max_duration,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
|
world_size=world_size,
|
||||||
|
rank=rank,
|
||||||
)
|
)
|
||||||
logging.info("About to create dev dataloader")
|
logging.info("About to create dev dataloader")
|
||||||
valid_dl = DataLoader(
|
valid_dl = DataLoader(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user