mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
set world_size and rank explicitly
update
This commit is contained in:
parent
2420d0c95f
commit
c75767f600
@ -885,14 +885,21 @@ def run(rank, world_size, args):
|
|||||||
sampler_state_dict["max_duration"] = params.max_duration
|
sampler_state_dict["max_duration"] = params.max_duration
|
||||||
|
|
||||||
train_dl = data_module.train_dataloaders(
|
train_dl = data_module.train_dataloaders(
|
||||||
train_cuts, sampler_state_dict=sampler_state_dict
|
train_cuts,
|
||||||
|
sampler_state_dict=sampler_state_dict,
|
||||||
|
world_size=world_size,
|
||||||
|
rank=rank,
|
||||||
)
|
)
|
||||||
|
|
||||||
if params.use_aishell:
|
if params.use_aishell:
|
||||||
valid_cuts = multi_dataset.aishell_dev_cuts()
|
valid_cuts = multi_dataset.aishell_dev_cuts()
|
||||||
else:
|
else:
|
||||||
valid_cuts = multi_dataset.dev_cuts()
|
valid_cuts = multi_dataset.dev_cuts()
|
||||||
valid_dl = data_module.valid_dataloaders(valid_cuts)
|
valid_dl = data_module.valid_dataloaders(
|
||||||
|
valid_cuts,
|
||||||
|
world_size=world_size,
|
||||||
|
rank=rank,
|
||||||
|
)
|
||||||
|
|
||||||
if args.tensorboard and rank == 0:
|
if args.tensorboard and rank == 0:
|
||||||
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
|
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user