set world_size and rank explicitly

update
This commit is contained in:
yfyeung 2025-05-10 17:32:14 +00:00
parent 2420d0c95f
commit c75767f600

View File

@ -885,14 +885,21 @@ def run(rank, world_size, args):
sampler_state_dict["max_duration"] = params.max_duration sampler_state_dict["max_duration"] = params.max_duration
train_dl = data_module.train_dataloaders( train_dl = data_module.train_dataloaders(
train_cuts, sampler_state_dict=sampler_state_dict train_cuts,
sampler_state_dict=sampler_state_dict,
world_size=world_size,
rank=rank,
) )
if params.use_aishell: if params.use_aishell:
valid_cuts = multi_dataset.aishell_dev_cuts() valid_cuts = multi_dataset.aishell_dev_cuts()
else: else:
valid_cuts = multi_dataset.dev_cuts() valid_cuts = multi_dataset.dev_cuts()
valid_dl = data_module.valid_dataloaders(valid_cuts) valid_dl = data_module.valid_dataloaders(
valid_cuts,
world_size=world_size,
rank=rank,
)
if args.tensorboard and rank == 0: if args.tensorboard and rank == 0:
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")