diff --git a/egs/speech_llm/ASR_LLM/zipformer_llm_zh/train.py b/egs/speech_llm/ASR_LLM/zipformer_llm_zh/train.py index 82ba1abb3..edea3bdb9 100755 --- a/egs/speech_llm/ASR_LLM/zipformer_llm_zh/train.py +++ b/egs/speech_llm/ASR_LLM/zipformer_llm_zh/train.py @@ -885,14 +885,21 @@ def run(rank, world_size, args): sampler_state_dict["max_duration"] = params.max_duration train_dl = data_module.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict + train_cuts, + sampler_state_dict=sampler_state_dict, + world_size=world_size, + rank=rank, ) if params.use_aishell: valid_cuts = multi_dataset.aishell_dev_cuts() else: valid_cuts = multi_dataset.dev_cuts() - valid_dl = data_module.valid_dataloaders(valid_cuts) + valid_dl = data_module.valid_dataloaders( + valid_cuts, + world_size=world_size, + rank=rank, + ) if args.tensorboard and rank == 0: tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")